diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 6c1f959dac15964b70dbdce04e039f986c9d3e30..1eceb4973920ff67ed772989695d7862b8c4021c 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -34,14 +34,18 @@ THE SOFTWARE. from typing import Any, Callable, Dict, Sequence, Type, Union -import pyopencl as cl from arraycontext.context import ArrayContext # {{{ array context factories class PytestArrayContextFactory: - pass + @classmethod + def is_available(cls) -> bool: + return True + + def __call__(self) -> ArrayContext: + raise NotImplementedError class PytestPyOpenCLArrayContextFactory(PytestArrayContextFactory): @@ -56,6 +60,14 @@ class PytestPyOpenCLArrayContextFactory(PytestArrayContextFactory): """ self.device = device + @classmethod + def is_available(cls) -> bool: + try: + import pyopencl # noqa: F401 + return True + except ImportError: + return False + def get_command_queue(self): # Get rid of leftovers from past tests. # CL implementations are surprisingly limited in how many @@ -66,14 +78,12 @@ class PytestPyOpenCLArrayContextFactory(PytestArrayContextFactory): from gc import collect collect() + import pyopencl as cl # On Intel CPU CL, existence of a command queue does not ensure that # the context survives. ctx = cl.Context([self.device]) return ctx, cl.CommandQueue(ctx) - def __call__(self) -> ArrayContext: - raise NotImplementedError - class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFactory): force_device_scalars = True @@ -107,8 +117,15 @@ class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars( force_device_scalars = False -class _PytestPytatoPyOpenCLArrayContextFactory( - PytestPyOpenCLArrayContextFactory): +class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory): + @classmethod + def is_available(cls) -> bool: + try: + import pyopencl # noqa: F401 + import pytato # noqa: F401 + return True + except ImportError: + return False @property def actx_class(self): @@ -137,6 +154,14 @@ class _PytestEagerJaxArrayContextFactory(PytestArrayContextFactory): def __init__(self, *args, **kwargs): pass + @classmethod + def is_available(cls) -> bool: + try: + import jax # noqa: F401 + return True + except ImportError: + return False + def __call__(self): from arraycontext import EagerJAXArrayContext from jax.config import config @@ -151,6 +176,15 @@ class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory): def __init__(self, *args, **kwargs): pass + @classmethod + def is_available(cls) -> bool: + try: + import jax # noqa: F401 + import pytato # noqa: F401 + return True + except ImportError: + return False + def __call__(self): from arraycontext import PytatoJAXArrayContext from jax.config import config @@ -254,9 +288,19 @@ def pytest_generate_tests_for_array_contexts( else: raise ValueError(f"unknown array contexts: {unknown_factories}") - unique_factories = set([ - _ARRAY_CONTEXT_FACTORY_REGISTRY.get(factory, factory) # type: ignore[misc] - for factory in unique_factories]) + available_factories = { + factory for key in unique_factories + for factory in [_ARRAY_CONTEXT_FACTORY_REGISTRY.get(key, key)] + if ( + not isinstance(factory, str) + and issubclass(factory, PytestArrayContextFactory) + and factory.is_available()) + } + + from pytools import partition + pyopencl_factories, other_factories = partition( + lambda factory: issubclass(factory, PytestPyOpenCLArrayContextFactory), + available_factories) # }}} @@ -271,6 +315,7 @@ def pytest_generate_tests_for_array_contexts( return arg_values, ids = cl_tools.get_pyopencl_fixture_arg_values() + empty_arg_dict = {k: None for k in arg_values[0]} # }}} @@ -283,23 +328,29 @@ def pytest_generate_tests_for_array_contexts( "'ctx_factory' / 'ctx_getter' as arguments.") arg_values_with_actx = [] - for arg_dict in arg_values: + + if pyopencl_factories: + for arg_dict in arg_values: + arg_values_with_actx.extend([ + {factory_arg_name: factory(arg_dict["device"]), **arg_dict} + for factory in pyopencl_factories + ]) + + if other_factories: arg_values_with_actx.extend([ - {factory_arg_name: factory(arg_dict["device"]), **arg_dict} - for factory in unique_factories + {factory_arg_name: factory(), **empty_arg_dict} + for factory in other_factories ]) else: arg_values_with_actx = arg_values - arg_value_tuples = [ - tuple(arg_dict[name] for name in arg_names) - for arg_dict in arg_values_with_actx - ] - # }}} - # Sort the actx's so that parallel pytest works - arg_value_tuples = sorted(arg_value_tuples, key=lambda x: x.__str__()) + # NOTE: sorts the args so that parallel pytest works + arg_value_tuples = sorted([ + tuple([arg_dict[name] for name in arg_names]) + for arg_dict in arg_values_with_actx + ], key=lambda x: str(x)) metafunc.parametrize(arg_names, arg_value_tuples, ids=ids) diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py index 1d40ae393ee370ad0bb93b701c7a08eb1e064018..f4d132ca8223e80198bd774ba9fc822fe824ff08 100644 --- a/test/test_pytato_arraycontext.py +++ b/test/test_pytato_arraycontext.py @@ -27,7 +27,7 @@ from arraycontext import pytest_generate_tests_for_array_contexts from arraycontext.pytest import _PytestPytatoPyOpenCLArrayContextFactory from pytools.tag import Tag - +import pytest import logging logger = logging.getLogger(__name__) @@ -79,10 +79,15 @@ class BazTag(Tag): def test_tags_preserved_after_freeze(actx_factory): + actx = actx_factory() + + from arraycontext.impl.pytato import _BasePytatoArrayContext + if not isinstance(actx, _BasePytatoArrayContext): + pytest.skip("only pytato-based array context are supported") + from numpy.random import default_rng rng = default_rng() - actx = actx_factory() foo = actx.thaw(actx.freeze( actx.from_numpy(rng.random((10, 4))) .tagged(FooTag()) @@ -100,7 +105,6 @@ if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) else: - from pytest import main - main([__file__]) + pytest.main([__file__]) # vim: fdm=marker