From 97940584273f457d89dcab0dcc0b3cf376d8dc68 Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Fri, 25 Jun 2021 17:18:55 -0500 Subject: [PATCH] small fixes --- arraycontext/fake_numpy.py | 1 - arraycontext/impl/pytato/__init__.py | 1 + arraycontext/loopy.py | 3 +-- arraycontext/pytest.py | 14 ++++++++++++++ test/test_arraycontext.py | 2 +- test/test_utils.py | 7 +++++-- 6 files changed, 22 insertions(+), 6 deletions(-) diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index cb1b3fc..1f208f3 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -236,7 +236,6 @@ class BaseFakeNumpyLinalgNamespace: return self._array_context.np.sum(abs(ary)**ord)**(1/ord) else: raise NotImplementedError(f"unsupported value of 'ord': {ord}") - # }}} diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 20e208f..c3fecf8 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -54,6 +54,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): def __init__(self, queue, allocator=None): super().__init__() + self._force_device_scalars = True self.queue = queue self.allocator = allocator self.np = self._get_fake_numpy_namespace() diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py index 2a086b7..8f2816d 100644 --- a/arraycontext/loopy.py +++ b/arraycontext/loopy.py @@ -51,8 +51,7 @@ def make_loopy_program(domains, statements, kernel_data=None, statements, kernel_data=kernel_data, options=_DEFAULT_LOOPY_OPTIONS, - # FIXME: Restore when https://github.com/inducer/loopy/pull/431 is merged - # default_offset=lp.auto, + default_offset=lp.auto, name=name, lang_version=MOST_RECENT_LANGUAGE_VERSION) diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index a939b05..9e938b4 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -87,10 +87,24 @@ class _DeprecatedPyOpenCLArrayContextFactory(_PyOpenCLArrayContextFactory): force_device_scalars = False +class _PytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory): + force_device_scalars = False + + def __call__(self): + from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext + return PytatoPyOpenCLArrayContext(self.get_command_queue()) + + def __str__(self): + return ("<Pytato array context factory for <pyopencl.Device '%s' on '%s'>" + % (self.device.name.strip(), + self.device.platform.name.strip())) + + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestPyOpenCLArrayContextFactory]] = { "pyopencl": _PyOpenCLArrayContextFactory, "pyopencl-deprecated": _DeprecatedPyOpenCLArrayContextFactory, + "pytato-pyopencl": _PytatoPyOpenCLArrayContextFactory, } diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 3e9ec2c..eb55fb4 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -41,7 +41,7 @@ logger = logging.getLogger(__name__) pytest_generate_tests = pytest_generate_tests_for_array_contexts([ - "pyopencl", "pyopencl-deprecated", + "pyopencl", "pyopencl-deprecated", "pytato-pyopencl" ]) diff --git a/test/test_utils.py b/test/test_utils.py index 63cadc7..e40dda1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -24,13 +24,16 @@ THE SOFTWARE. """ from arraycontext import ( # noqa: F401 - pytest_generate_tests_for_array_contexts - as pytest_generate_tests, + pytest_generate_tests_for_array_contexts, _acf) import logging logger = logging.getLogger(__name__) +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + "pyopencl", "pyopencl-deprecated", "pytato-pyopencl" + ]) + def test_pt_actx_key_stringification_uniqueness(): from arraycontext.impl.pytato.compile import _ary_container_key_stringifier -- GitLab