diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index e93a8b38bd8528d8719dfe818bd58c56214a3c66..6c1f959dac15964b70dbdce04e039f986c9d3e30 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -1,6 +1,7 @@ """ .. currentmodule:: arraycontext +.. autoclass:: PytestArrayContextFactory .. autoclass:: PytestPyOpenCLArrayContextFactory .. autofunction:: pytest_generate_tests_for_array_contexts @@ -39,7 +40,11 @@ from arraycontext.context import ArrayContext # {{{ array context factories -class PytestPyOpenCLArrayContextFactory: +class PytestArrayContextFactory: + pass + + +class PytestPyOpenCLArrayContextFactory(PytestArrayContextFactory): """ .. automethod:: __init__ .. automethod:: __call__ @@ -108,7 +113,9 @@ class _PytestPytatoPyOpenCLArrayContextFactory( @property def actx_class(self): from arraycontext import PytatoPyOpenCLArrayContext - return PytatoPyOpenCLArrayContext + actx_cls = PytatoPyOpenCLArrayContext + actx_cls.transform_loopy_program = lambda s, t_unit: t_unit + return actx_cls def __call__(self): # The ostensibly pointless assignment to *ctx* keeps the CL context alive @@ -126,18 +133,48 @@ class _PytestPytatoPyOpenCLArrayContextFactory( self.device.platform.name.strip())) +class _PytestEagerJaxArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + pass + + def __call__(self): + from arraycontext import EagerJAXArrayContext + from jax.config import config + config.update("jax_enable_x64", True) + return EagerJAXArrayContext() + + def __str__(self): + return "" + + +class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + pass + + def __call__(self): + from arraycontext import PytatoJAXArrayContext + from jax.config import config + config.update("jax_enable_x64", True) + return PytatoJAXArrayContext() + + def __str__(self): + return "" + + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ - Dict[str, Type[PytestPyOpenCLArrayContextFactory]] = { + Dict[str, Type[PytestArrayContextFactory]] = { "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, "pyopencl-deprecated": _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars, - "pytato-pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, + "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, + "pytato:jax": _PytestPytatoJaxArrayContextFactory, + "eagerjax": _PytestEagerJaxArrayContextFactory, } def register_pytest_array_context_factory( name: str, - factory: Type[PytestPyOpenCLArrayContextFactory]) -> None: + factory: Type[PytestArrayContextFactory]) -> None: if name in _ARRAY_CONTEXT_FACTORY_REGISTRY: raise ValueError(f"factory '{name}' already exists") @@ -149,7 +186,7 @@ def register_pytest_array_context_factory( # {{{ pytest integration def pytest_generate_tests_for_array_contexts( - factories: Sequence[Union[str, Type[PytestPyOpenCLArrayContextFactory]]], *, + factories: Sequence[Union[str, Type[PytestArrayContextFactory]]], *, factory_arg_name: str = "actx_factory", ) -> Callable[[Any], None]: """Parametrize tests for pytest to use an :class:`~arraycontext.ArrayContext`. diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 95666b09a9abd6b85ff81e3190ff06be16029572..6080c5772c417cf34372ca4b9d817fd1ea58e226 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -42,7 +42,9 @@ from arraycontext import ( # noqa: F401 pytest_generate_tests_for_array_contexts, ) from arraycontext.pytest import (_PytestPyOpenCLArrayContextFactoryWithClass, - _PytestPytatoPyOpenCLArrayContextFactory) + _PytestPytatoPyOpenCLArrayContextFactory, + _PytestEagerJaxArrayContextFactory, + _PytestPytatoJaxArrayContextFactory) import logging @@ -89,6 +91,8 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts([ _PyOpenCLArrayContextForTestsFactory, _PyOpenCLArrayContextWithHostScalarsForTestsFactory, _PytatoPyOpenCLArrayContextForTestsFactory, + _PytestEagerJaxArrayContextFactory, + _PytestPytatoJaxArrayContextFactory, ]) @@ -291,7 +295,6 @@ def assert_close_to_numpy_in_containers(actx, op, args): ("any", 1, np.float64), ("all", 1, np.float64), ("arctan", 1, np.float64), - ("atan", 1, np.float64), # float + complex ("sin", 1, np.float64),