diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 8a1e027416cb7d7b509069b42cd1fd9009550ef1..e74a6aef2041c626f77b69c10e772c78ee77e2f2 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -34,6 +34,7 @@ THE SOFTWARE. from typing import Any, Callable, Dict, Sequence, Type, Union +from arraycontext import NumpyArrayContext from arraycontext.context import ArrayContext @@ -221,6 +222,26 @@ class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory): return "<PytatoJAXArrayContext>" +# {{{ _PytestArrayContextFactory + +class _NumpyArrayContextForTests(NumpyArrayContext): + def transform_loopy_program(self, t_unit): + return t_unit + + +class _PytestNumpyArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + super().__init__() + + def __call__(self): + return _NumpyArrayContextForTests() + + def __str__(self): + return "<NumpyArrayContext>" + +# }}} + + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestArrayContextFactory]] = { "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, @@ -229,6 +250,7 @@ _ARRAY_CONTEXT_FACTORY_REGISTRY: \ "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, + "numpy": _PytestNumpyArrayContextFactory, } diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 3f06156b76faa5376adff66000a1890307ff7623..ffd7553d5b321665288213160d48bf89555eccd8 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -46,6 +46,7 @@ from arraycontext import ( ) from arraycontext.pytest import ( _PytestEagerJaxArrayContextFactory, + _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, @@ -97,6 +98,7 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts([ _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, + _PytestNumpyArrayContextFactory, ])