diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index cb1b3fc63e01af3c297a8f950fb5fb3d63dfd509..1f208f3dd02f878e7be1523f8586c2caa3f1ddcd 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -236,7 +236,6 @@ class BaseFakeNumpyLinalgNamespace: return self._array_context.np.sum(abs(ary)**ord)**(1/ord) else: raise NotImplementedError(f"unsupported value of 'ord': {ord}") - # }}} diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 20e208ff56953828c8e136697d1cce41c12c2098..c3fecf835aac9f27e7c87df05c5f9a501a9a0712 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -54,6 +54,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): def __init__(self, queue, allocator=None): super().__init__() + self._force_device_scalars = True self.queue = queue self.allocator = allocator self.np = self._get_fake_numpy_namespace() diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py index 2a086b7b67a0c44c061ad6e9f96ab6cded476f00..8f2816db58fe4b0a93a53542f2113c9cafbeea8a 100644 --- a/arraycontext/loopy.py +++ b/arraycontext/loopy.py @@ -51,8 +51,7 @@ def make_loopy_program(domains, statements, kernel_data=None, statements, kernel_data=kernel_data, options=_DEFAULT_LOOPY_OPTIONS, - # FIXME: Restore when https://github.com/inducer/loopy/pull/431 is merged - # default_offset=lp.auto, + default_offset=lp.auto, name=name, lang_version=MOST_RECENT_LANGUAGE_VERSION) diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index a939b05c585713c291524b60b50f98284719b11b..9e938b4e57391fe56b73893d0456a14c52ce2b48 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -87,10 +87,24 @@ class _DeprecatedPyOpenCLArrayContextFactory(_PyOpenCLArrayContextFactory): force_device_scalars = False +class _PytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory): + force_device_scalars = False + + def __call__(self): + from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext + return PytatoPyOpenCLArrayContext(self.get_command_queue()) + + def __str__(self): + return ("<Pytato array context factory for <pyopencl.Device '%s' on '%s'>" + % (self.device.name.strip(), + self.device.platform.name.strip())) + + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestPyOpenCLArrayContextFactory]] = { "pyopencl": _PyOpenCLArrayContextFactory, "pyopencl-deprecated": _DeprecatedPyOpenCLArrayContextFactory, + "pytato-pyopencl": _PytatoPyOpenCLArrayContextFactory, } diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 3e9ec2c1242cdd18af4ebe920e627159f74bc178..eb55fb431f62547070f2dfedd7f7dac7894efe4a 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -41,7 +41,7 @@ logger = logging.getLogger(__name__) pytest_generate_tests = pytest_generate_tests_for_array_contexts([ - "pyopencl", "pyopencl-deprecated", + "pyopencl", "pyopencl-deprecated", "pytato-pyopencl" ]) diff --git a/test/test_utils.py b/test/test_utils.py index 63cadc7c3b836604457570fa72429870f7c1f442..e40dda193a7795559ff6070bfd2a85b2e8060b15 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -24,13 +24,16 @@ THE SOFTWARE. """ from arraycontext import ( # noqa: F401 - pytest_generate_tests_for_array_contexts - as pytest_generate_tests, + pytest_generate_tests_for_array_contexts, _acf) import logging logger = logging.getLogger(__name__) +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + "pyopencl", "pyopencl-deprecated", "pytato-pyopencl" + ]) + def test_pt_actx_key_stringification_uniqueness(): from arraycontext.impl.pytato.compile import _ary_container_key_stringifier