diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3f7e7601d9f9995b2a2fd7044b99a0818771864a..93468468fe838c497d3f93962ded60d38fee8c45 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -21,6 +21,7 @@ Python 3 Nvidia Titan V: export PYOPENCL_TEST=nvi:titan build_py_project_in_venv pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html + pip install git+https://gitlab.tiker.net/kaushikcfd/pycuda.git@pure_scalar#egg=pycuda test_py_project tags: diff --git a/.pylintrc-local.yml b/.pylintrc-local.yml new file mode 100644 index 0000000000000000000000000000000000000000..debc66374959488165546792776366912f723171 --- /dev/null +++ b/.pylintrc-local.yml @@ -0,0 +1,3 @@ +- arg: ignored-modules + val: + - pycuda diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 06e0b96c5f661c40aaf92d08f6aa5851daa6ddf4..7a1a3f5ecff420ab2225c27a5bcfd3eda7435ad6 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -75,11 +75,13 @@ from .container.traversal import ( from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import (PytatoPyOpenCLArrayContext, PytatoJAXArrayContext) +from .impl.pycuda import PyCUDAArrayContext from .impl.jax import EagerJAXArrayContext from .pytest import ( PytestArrayContextFactory, PytestPyOpenCLArrayContextFactory, + PytestPyCUDAArrayContextFactory, pytest_generate_tests_for_array_contexts, pytest_generate_tests_for_pyopencl_array_context) @@ -119,7 +121,7 @@ __all__ = ( "from_numpy", "to_numpy", "with_array_context", "outer", - "PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext", + "PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext", "PyCUDAArrayContext", "PytatoJAXArrayContext", "EagerJAXArrayContext", @@ -127,6 +129,7 @@ __all__ = ( "PytestArrayContextFactory", "PytestPyOpenCLArrayContextFactory", + "PytestPyCUDAArrayContextFactory", "pytest_generate_tests_for_array_contexts", "pytest_generate_tests_for_pyopencl_array_context" ) diff --git a/arraycontext/impl/pycuda/__init__.py b/arraycontext/impl/pycuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0b18f964483114ac6fcc265c65c0bb828ff55e49 --- /dev/null +++ b/arraycontext/impl/pycuda/__init__.py @@ -0,0 +1,279 @@ +""" +.. currentmodule:: arraycontext +.. autoclass:: PyCUDAArrayContext +""" + +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from warnings import warn +from typing import (Callable, Optional, Tuple, TYPE_CHECKING, Dict) + +import numpy as np + +from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike +from arraycontext.container.traversal import (rec_map_array_container, + with_array_context) +if TYPE_CHECKING: + import loopy as lp + +# {{{ PyCUDAArrayContext + + +class PyCUDAArrayContext(ArrayContext): + """ + A :class:`ArrayContext` that uses :class:`pycuda.gpuarray.GPUArray` instances + for its base array class. + + .. automethod:: __init__ + """ + + def __init__(self, allocator=None): + import pycuda + import pycuda.gpuarray as gpuarray + super().__init__() + if allocator is None: + self.allocator = pycuda.driver.mem_alloc + from warnings import warn + warn("PyCUDAArrayContext created without an allocator on a GPU. " + "This can lead to high numbers of memory allocations. " + "Please consider using a pycuda.autoinit. " + "Run with allocator=False to disable this warning.") + else: + self.allocator = allocator + self.array_types = (gpuarray.GPUArray,) + self._loopy_transform_cache: \ + Dict["lp.TranslationUnit", "lp.TranslationUnit"] = {} + + def _get_fake_numpy_namespace(self): + from arraycontext.impl.pycuda.fake_numpy import PyCUDAFakeNumpyNamespace + return PyCUDAFakeNumpyNamespace(self) + + def _rec_map_container( + self, func: Callable[[Array], Array], array: ArrayOrContainer, + allowed_types: Optional[Tuple[type, ...]] = None, *, + default_scalar: Optional[ScalarLike] = None, + strict: bool = False) -> ArrayOrContainer: + + if allowed_types is None: + allowed_types = self.array_types + + def _wrapper(ary): + if isinstance(ary, allowed_types): + return func(ary) + elif np.isscalar(ary): + if default_scalar is None: + return ary + else: + return np.array(ary).dtype.type(default_scalar) + else: + raise TypeError( + f"{type(self).__name__}.{func.__name__[1:]} invoked with " + f"an unsupported array type: got '{type(ary).__name__}', " + f"but expected one of {allowed_types}") + + return rec_map_array_container(_wrapper, array) + + # {{{ ArrayContext interface + + def empty(self, shape, dtype): + import pycuda.gpuarray as gpuarray + return gpuarray.empty(shape=shape, dtype=dtype, + allocator=self.allocator) + + def zeros(self, shape, dtype): + import pycuda.gpuarray as gpuarray + return gpuarray.zeros(shape=shape, dtype=dtype, + allocator=self.allocator) + + def zeros_like(self, array): + def _zeros_like(ary): + return self.zeros(ary.shape, ary.dtype) + + return self._rec_map_container(_zeros_like, array, default_scalar=0) + + def from_numpy(self, array): + import pycuda.gpuarray as gpuarray + + def _from_numpy(ary): + return gpuarray.to_gpu(ary, allocator=self.allocator) + + return with_array_context( + self._rec_map_container(_from_numpy, array, (np.ndarray,), strict=True), + actx=self) + + def to_numpy(self, array): + def _to_numpy(ary): + if np.isscalar(ary): + return ary + return ary.get() + return with_array_context( + self._rec_map_container(_to_numpy, array), + actx=None) + + def call_loopy(self, t_unit, **kwargs): + try: + import loopy as lp + t_unit.target = lp.PyCudaTarget() + t_unit = self._loopy_transform_cache[t_unit] + except KeyError: + orig_t_unit = t_unit + t_unit = self.transform_loopy_program(t_unit) + self._loopy_transform_cache[orig_t_unit] = t_unit + del orig_t_unit + + evt, result = t_unit(**kwargs, allocator=self.allocator) + + if self._wait_event_queue_length is not False: + prg_name = t_unit.default_entrypoint.name + wait_event_queue = self._kernel_name_to_wait_event_queue.setdefault( + prg_name, []) + + wait_event_queue.append(evt) + if len(wait_event_queue) > self._wait_event_queue_length: + wait_event_queue.pop(0).wait() + + return {name: ary for name, ary in result.items()} + # raise NotImplementedError( + # "Loopy doesn't support calling PyCUDA kenels (yet).") + + def freeze(self, array): + import pycuda.gpuarray as gpuarray + + def _rec_freeze(ary): + if isinstance(ary, gpuarray.GPUArray): + return ary + else: + raise TypeError(f"{type(self).__name__} cannot freeze" + f" arrays of type '{type(ary).__name__}'.") + + return with_array_context(self._rec_map_container(_rec_freeze, array), + actx=None) + + def thaw(self, array): + def _thaw(ary): + return ary + return with_array_context(array, actx=self) + + # }}} + + # {{{ transform_loopy_program + + def transform_loopy_program(self, t_unit): + from warnings import warn + warn("Using arraycontext.PyCUDArrayContext.transform_loopy_program " + "to transform a program. This is deprecated and will stop working " + "in 2022. Instead, subclass PyCUDAArrayContext and implement " + "the specific logic required to transform the program for your " + "package or application. Check higher-level packages " + "(e.g. meshmode), which may already have subclasses you may want " + "to build on.", + DeprecationWarning, stacklevel=2) + + # accommodate loopy with and without kernel callables + + import loopy as lp + t_unit.target = lp.PyCudaTarget() + default_entrypoint = t_unit.default_entrypoint + options = default_entrypoint.options + if not (options.return_dict and options.no_numpy): + raise ValueError("Loopy kernel passed to call_loopy must " + "have return_dict and no_numpy options set. " + "Did you use arraycontext.make_loopy_program " + "to create this kernel?") + + all_inames = default_entrypoint.all_inames() + # FIXME: This could be much smarter. + inner_iname = None + + # import with underscore to avoid DeprecationWarning + from arraycontext.metadata import _FirstAxisIsElementsTag + + if (len(default_entrypoint.instructions) == 1 + and isinstance(default_entrypoint.instructions[0], lp.Assignment) + and any(isinstance(tag, _FirstAxisIsElementsTag) + # FIXME: Firedrake branch lacks kernel tags + for tag in getattr(default_entrypoint, "tags", ()))): + stmt, = default_entrypoint.instructions + + out_inames = [v.name for v in stmt.assignee.index_tuple] + assert out_inames + outer_iname = out_inames[0] + if len(out_inames) >= 2: + inner_iname = out_inames[1] + + elif "iel" in all_inames: + outer_iname = "iel" + + if "idof" in all_inames: + inner_iname = "idof" + + elif "i0" in all_inames: + outer_iname = "i0" + + if "i1" in all_inames: + inner_iname = "i1" + + elif not all_inames: + # no loops, nothing to transform + return t_unit + + else: + raise RuntimeError( + "Unable to reason what outer_iname and inner_iname " + f"needs to be; all_inames is given as: {all_inames}" + ) + + if inner_iname is not None: + t_unit = lp.split_iname(t_unit, inner_iname, 16, inner_tag="l.0") + t_unit = lp.tag_inames(t_unit, {outer_iname: "g.0"}) + + return t_unit + + def clone(self): + return type(self)() + + def tag(self, tags, array): + warn("Tagging in PyCUDAArrayContext is a noop", stacklevel=2) + return array + + def tag_axis(self, iaxis, tags, array): + warn("Tagging in PyCUDAArrayContext is a noop", stacklevel=2) + return array + + @property + def permits_inplace_modification(self): + return True + + @property + def supports_nonscalar_broadcasting(self): + return False + + @property + def permits_advanced_indexing(self): + return False + +# }}} + +# vim: foldmethod=marker diff --git a/arraycontext/impl/pycuda/fake_numpy.py b/arraycontext/impl/pycuda/fake_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..8e7d912469b3adbe941f0b61ae45b28f2cb9995b --- /dev/null +++ b/arraycontext/impl/pycuda/fake_numpy.py @@ -0,0 +1,211 @@ +""" +.. currentmodule:: arraycontext +.. autoclass:: PyCUDAArrayContext +""" +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from functools import partial, reduce +import operator +import numpy as np + +from arraycontext.fake_numpy import \ + BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace +from arraycontext.container import NotAnArrayContainerError, serialize_container +from arraycontext.container.traversal import ( + rec_multimap_array_container, rec_map_array_container, + rec_map_reduce_array_container, + ) + +import pycuda.cumath as cumath + +try: + import pycuda.gpuarray as gpuarray +except ImportError: + pass + + +# {{{ fake numpy + +class PyCUDAFakeNumpyNamespace(BaseFakeNumpyNamespace): + + def _get_fake_numpy_linalg_namespace(self): + return _PyCUDAFakeNumpyLinalgNamespace(self._array_context) + + def __getattr__(self, name): + return partial(rec_map_array_container, getattr(cumath, name)) + + # {{{ comparisons + # FIXME: This should be documentation, not a comment. + # These are here mainly because some arrays may choose to interpret + # equality comparison as a binary predicate of structural identity, + # i.e. more like "are you two equal", and not like numpy semantics. + # These operations provide access to numpy-style comparisons in that + # case. + + def all(self, a): + import pycuda.gpuarray as gpuarray + + def _all(ary): + if np.isscalar(ary): + return np.int8(all([ary])) + return gpuarray.all(ary) + + result = rec_map_reduce_array_container( + partial(reduce, partial(gpuarray.minimum)), + _all, + a) + return result + + def array_equal(self, x, y): + import pycuda.gpuarray as gpuarray + actx = self._array_context + + true = actx.from_numpy(np.int8(True)) + false = actx.from_numpy(np.int8(False)) + + def rec_equal(x, y): + if type(x) != type(y): + return false + + try: + iterable = zip(serialize_container(x), serialize_container(y)) + except NotAnArrayContainerError: + if x.shape != y.shape: + return false + else: + return gpuarray.all(x == y) + else: + return reduce( + partial(gpuarray.minimum), + [rec_equal(ix, iy) for (_, ix), (_, iy) in iterable], + true + ) + + result = rec_equal(x, y) + return result + + def equal(self, x, y): + return rec_multimap_array_container(operator.eq, x, y) + + def not_equal(self, x, y): + return rec_multimap_array_container(operator.ne, x, y) + + def greater(self, x, y): + return rec_multimap_array_container(operator.gt, x, y) + + def greater_equal(self, x, y): + return rec_multimap_array_container(operator.ge, x, y) + + def less(self, x, y): + return rec_multimap_array_container(operator.lt, x, y) + + def less_equal(self, x, y): + return rec_multimap_array_container(operator.le, x, y) + + def maximum(self, x, y): + return rec_multimap_array_container(gpuarray.maximum, x, y) + + def minimum(self, x, y): + return rec_multimap_array_container(gpuarray.minimum, x, y) + + def where(self, criterion, then, else_): + return rec_multimap_array_container(gpuarray.where, criterion, then, else_) + + def sum(self, a, dtype=None): + def _gpuarray_sum(ary): + if dtype not in [ary.dtype, None]: + raise NotImplementedError + + return gpuarray.sum(ary) + + return rec_map_reduce_array_container(sum, _gpuarray_sum, a) + + def min(self, a): + return rec_map_reduce_array_container( + partial(reduce, partial(gpuarray.minimum)), partial(gpuarray.min), a) + + def max(self, a): + return rec_map_reduce_array_container( + partial(reduce, partial(gpuarray.maximum)), partial(gpuarray.max), a) + + def stack(self, arrays, axis=0): + return rec_multimap_array_container( + lambda *args: gpuarray.stack(arrays=args, axis=axis), + *arrays) + + def ones_like(self, ary): + return self.full_like(ary, 1) + + def full_like(self, ary, fill_value): + def _full_like(subary): + ones = self._array_context.empty_like(subary) + ones.fill(fill_value) + return ones + return self._array_context._rec_map_container( + _full_like, ary, default_scalar=fill_value) + + def reshape(self, a, newshape, order="C"): + return rec_map_array_container( + lambda ary: gpuarray.reshape(ary, newshape, order=order), + a) + + def concatenate(self, arrays, axis=0): + return rec_multimap_array_container( + lambda *args: gpuarray.concatenate(arrays=args, axis=axis), + *arrays) + + def ravel(self, a, order="C"): + def _rec_ravel(a): + if order in "FC": + return gpuarray.reshape(a, -1, order=order) + elif order == "A": + if a.flags.f_contiguous: + return gpuarray.reshape(a, -1, order="F") + elif a.flags.c_contiguous: + return gpuarray.reshape(a, -1, order="C") + else: + raise ValueError("For `order='A'`, array should be either" + " F-contiguous or C-contiguous.") + elif order == "K": + raise NotImplementedError("PyCUDAArrayContext.np.ravel not " + "implemented for 'order=K'") + else: + raise ValueError("`order` can be one of 'F', 'C', 'A' or 'K'. " + f"(got {order})") + return rec_map_array_container(_rec_ravel, a) + + +# }}} + + +# {{{ fake np.linalg + +class _PyCUDAFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): + pass + +# }}} + + +# vim: foldmethod=marker diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 1eceb4973920ff67ed772989695d7862b8c4021c..e3808565faf5b3ec48f6ab0f62e7f30d3b8c813f 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -3,6 +3,7 @@ .. autoclass:: PytestArrayContextFactory .. autoclass:: PytestPyOpenCLArrayContextFactory +.. autoclass:: PytestPyCUDAArrayContextFactory .. autofunction:: pytest_generate_tests_for_array_contexts .. autofunction:: pytest_generate_tests_for_pyopencl_array_context @@ -150,6 +151,38 @@ class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory self.device.platform.name.strip())) +class PytestPyCUDAArrayContextFactory(PytestArrayContextFactory): + """ + .. automethod:: __init__ + .. automethod:: __call__ + """ + + @classmethod + def is_available(cls) -> bool: + try: + import pycuda # noqa: F401 + return True + except ImportError: + return False + + +class _PytestPyCUDAArrayContextFactory( + PytestPyCUDAArrayContextFactory): + + @property + def actx_class(self): + from arraycontext import PyCUDAArrayContext + return PyCUDAArrayContext + + def __call__(self): + import pycuda.autoinit # noqa: F401 + actx_class = self.actx_class() + return actx_class + + def __str__(self): + return "" + + class _PytestEagerJaxArrayContextFactory(PytestArrayContextFactory): def __init__(self, *args, **kwargs): pass @@ -197,6 +230,7 @@ class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory): _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestArrayContextFactory]] = { + "pycuda": _PytestPyCUDAArrayContextFactory, "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, "pyopencl-deprecated": _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars, diff --git a/doc/array_context.rst b/doc/array_context.rst index 85a6cc44d93bd66c7f9d436e00b29b0a5a0f14a6..f2c118f6e7451d203676cb0edde2ae5d8b9db26b 100644 --- a/doc/array_context.rst +++ b/doc/array_context.rst @@ -32,6 +32,11 @@ Array context based on :mod:`jax.numpy` .. automodule:: arraycontext.impl.jax +Array context :mod:`pycuda.gpuarray` +------------------------------------------------------------- + +.. automodule:: arraycontext.impl.pycuda + .. _numpy-coverage: :mod:`numpy` coverage diff --git a/doc/conf.py b/doc/conf.py index 84a6d64088562c86470a9f95cdc44f76f4f8f67a..58a7a4b01eb820107f9c3d7a0f5d946e25896de1 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -20,6 +20,7 @@ intersphinx_mapping = { "https://documen.tician.de/pytools": None, "https://documen.tician.de/pymbolic": None, "https://documen.tician.de/pyopencl": None, + "https://documen.tician.de/pycuda": None, "https://documen.tician.de/pytato": None, "https://documen.tician.de/loopy": None, "https://documen.tician.de/meshmode": None, diff --git a/setup.cfg b/setup.cfg index b24271f5a07c13fe72c1de7a1fb2a953a19594db..6a7f5307564bbdbe06f3d9483ad082b20f0ceaae 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,3 +30,6 @@ ignore_missing_imports = True [mypy-pyopencl.*] ignore_missing_imports = True + +[mypy-pycuda.*] +ignore_missing_imports = True \ No newline at end of file diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 842d108e5d6cb63b083f2659a9aec51f8170ca43..99ee30a9e7e387658c75b422768448d485ef128b 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -36,6 +36,7 @@ from arraycontext import ( PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, EagerJAXArrayContext, + PyCUDAArrayContext, ArrayContainer, to_numpy, tag_axes) from arraycontext import ( # noqa: F401 @@ -43,6 +44,7 @@ from arraycontext import ( # noqa: F401 ) from arraycontext.pytest import (_PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoPyOpenCLArrayContextFactory, + _PytestPyCUDAArrayContextFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory) @@ -72,6 +74,16 @@ class _PytatoPyOpenCLArrayContextForTests(PytatoPyOpenCLArrayContext): return t_unit +class _PyCUDAArrayContextForTests(PyCUDAArrayContext): + """Like :class:`PyCUDArrayContext`, but applies no program + transformations whatsoever. Only to be used for testing internal to + :mod:`arraycontext`. + """ + + def transform_loopy_program(self, t_unit): + return t_unit + + class _PyOpenCLArrayContextWithHostScalarsForTestsFactory( _PytestPyOpenCLArrayContextFactoryWithClass): actx_class = _PyOpenCLArrayContextForTests @@ -87,12 +99,18 @@ class _PytatoPyOpenCLArrayContextForTestsFactory( actx_class = _PytatoPyOpenCLArrayContextForTests +class _PyCUDAArrayContextForTestsFactory( + _PytestPyCUDAArrayContextFactory): + actx_class = _PyCUDAArrayContextForTests + + pytest_generate_tests = pytest_generate_tests_for_array_contexts([ _PyOpenCLArrayContextForTestsFactory, _PyOpenCLArrayContextWithHostScalarsForTestsFactory, _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, + _PyCUDAArrayContextForTestsFactory, ]) @@ -379,7 +397,11 @@ def assert_close_to_numpy_in_containers(actx, op, args): ]) def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype): actx = actx_factory() - if not hasattr(actx.np, sym_name): + if ( + not hasattr(actx.np, sym_name) + or (isinstance(actx, PyCUDAArrayContext) + and (sym_name in ["arctan2", "arctan"])) + ): pytest.skip(f"'{sym_name}' not implemented on '{type(actx).__name__}'") ndofs = 512 @@ -720,6 +742,8 @@ def test_array_equal(actx_factory): ]) def test_array_context_einsum_array_manipulation(actx_factory, spec): actx = actx_factory() + if isinstance(actx, PyCUDAArrayContext): + pytest.skip("Waiting for loopy to be more capable") mat = actx.from_numpy(np.random.randn(10, 10)) res = actx.to_numpy(actx.einsum(spec, mat, @@ -735,6 +759,8 @@ def test_array_context_einsum_array_manipulation(actx_factory, spec): ]) def test_array_context_einsum_array_matmatprods(actx_factory, spec): actx = actx_factory() + if isinstance(actx, PyCUDAArrayContext): + pytest.skip("Waiting for loopy to be more capable") mat_a = actx.from_numpy(np.random.randn(5, 5)) mat_b = actx.from_numpy(np.random.randn(5, 5)) @@ -749,6 +775,8 @@ def test_array_context_einsum_array_matmatprods(actx_factory, spec): ]) def test_array_context_einsum_array_tripleprod(actx_factory, spec): actx = actx_factory() + if isinstance(actx, PyCUDAArrayContext): + pytest.skip("PyCUDAArrayContext.einsum not implemented.") mat_a = actx.from_numpy(np.random.randn(7, 5)) mat_b = actx.from_numpy(np.random.randn(5, 7)) @@ -1531,8 +1559,8 @@ def test_to_numpy_on_frozen_arrays(actx_factory): def test_tagging(actx_factory): actx = actx_factory() - if isinstance(actx, EagerJAXArrayContext): - pytest.skip("Eager JAX has no tagging support") + if isinstance(actx, (EagerJAXArrayContext, PyCUDAArrayContext)): + pytest.skip(f"{type(actx).__name__} has no tagging support") from pytools.tag import Tag