From c745d2e649d274506d7f3a4acd7be7fe268263d5 Mon Sep 17 00:00:00 2001
From: Alex Fikl <alexfikl@gmail.com>
Date: Thu, 24 Jun 2021 19:59:45 -0500
Subject: [PATCH] Force device scalars on PyOpenCLArrayContext (#43)

* add a flag to force device scalars on the CL array context

* remove local pylintrc from meshmode

* parametrize tests for force_device_scalars

* add docs for force_device_scalars

* add proper typing hints to _loopy_transform_cache

* use np.isscalar instead of isinstance(Number)

* fix type hints for _loopy_transform_cache

* use specific type in _kernel_name_to_wait_event_queue

* add pytest_generate_tests_for_array_context

* add env var to select array context for tests

* fix mypy issues

* add a documented PytestArrayContextFactory

* update test generator to return a closure

* add a way to register factories as strings

* actually fix the implementation

* fix doc formatting

* better names for global variables

* move norm to BaseFakeNumpyLinalgNamespace

* simplify and deduplicate pytest fixture setup
---
 .pylintrc-local.yml                      |   8 -
 arraycontext/__init__.py                 |   7 +-
 arraycontext/fake_numpy.py               |  64 +++++-
 arraycontext/impl/pyopencl/__init__.py   |  40 +++-
 arraycontext/impl/pyopencl/fake_numpy.py |  91 ++-------
 arraycontext/pytest.py                   | 240 +++++++++++++++++++----
 run-mypy.sh                              |   2 +-
 test/test_arraycontext.py                |  22 ++-
 8 files changed, 335 insertions(+), 139 deletions(-)
 delete mode 100644 .pylintrc-local.yml

diff --git a/.pylintrc-local.yml b/.pylintrc-local.yml
deleted file mode 100644
index b3478b1..0000000
--- a/.pylintrc-local.yml
+++ /dev/null
@@ -1,8 +0,0 @@
-- arg: ignore
-  val:
-    - firedrake
-    - to_firedrake.py
-    - from_firedrake.py
-    - test_firedrake_interop.py
-- arg: extension-pkg-whitelist
-  val: mayavi
diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index 34df04b..f94fb81 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -52,7 +52,10 @@ from .container.traversal import (
 
 from .impl.pyopencl import PyOpenCLArrayContext
 
-from .pytest import pytest_generate_tests_for_pyopencl_array_context
+from .pytest import (
+        PytestPyOpenCLArrayContextFactory,
+        pytest_generate_tests_for_array_contexts,
+        pytest_generate_tests_for_pyopencl_array_context)
 
 from .loopy import make_loopy_program
 
@@ -81,6 +84,8 @@ __all__ = (
 
         "make_loopy_program",
 
+        "PytestPyOpenCLArrayContextFactory",
+        "pytest_generate_tests_for_array_contexts",
         "pytest_generate_tests_for_pyopencl_array_context"
         )
 
diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py
index 6b0163d..1f208f3 100644
--- a/arraycontext/fake_numpy.py
+++ b/arraycontext/fake_numpy.py
@@ -24,7 +24,7 @@ THE SOFTWARE.
 
 
 import numpy as np
-from arraycontext.container import is_array_container
+from arraycontext.container import is_array_container, serialize_container
 from arraycontext.container.traversal import (
         rec_map_array_container, multimapped_over_array_containers)
 
@@ -170,10 +170,72 @@ class BaseFakeNumpyNamespace:
 
 # {{{ BaseFakeNumpyLinalgNamespace
 
+def _scalar_list_norm(ary, ord):
+    if ord is None:
+        ord = 2
+
+    from numbers import Number
+    if ord == np.inf:
+        return max(ary)
+    elif ord == -np.inf:
+        return min(ary)
+    elif isinstance(ord, Number) and ord > 0:
+        return sum(iary**ord for iary in ary)**(1/ord)
+    else:
+        raise NotImplementedError(f"unsupported value of 'ord': {ord}")
+
+
 class BaseFakeNumpyLinalgNamespace:
     def __init__(self, array_context):
         self._array_context = array_context
 
+    def norm(self, ary, ord=None):
+        from numbers import Number
+
+        if isinstance(ary, Number):
+            return abs(ary)
+
+        actx = self._array_context
+
+        try:
+            from meshmode.dof_array import DOFArray, flat_norm
+        except ImportError:
+            pass
+        else:
+            if isinstance(ary, DOFArray):
+                from warnings import warn
+                warn("Taking an actx.np.linalg.norm of a DOFArray is deprecated. "
+                        "(DOFArrays use 2D arrays internally, and "
+                        "actx.np.linalg.norm should compute matrix norms of those.) "
+                        "This will stop working in 2022. "
+                        "Use meshmode.dof_array.flat_norm instead.",
+                        DeprecationWarning, stacklevel=2)
+
+                return flat_norm(ary, ord=ord)
+
+        if is_array_container(ary):
+            return _scalar_list_norm([
+                self.norm(subary, ord=ord)
+                for _, subary in serialize_container(ary)
+                ], ord=ord)
+
+        if ord is None:
+            return self.norm(actx.np.ravel(ary, order="A"), 2)
+
+        if len(ary.shape) != 1:
+            raise NotImplementedError("only vector norms are implemented")
+
+        if ary.size == 0:
+            return 0
+
+        if ord == np.inf:
+            return self._array_context.np.max(abs(ary))
+        elif ord == -np.inf:
+            return self._array_context.np.min(abs(ary))
+        elif isinstance(ord, Number) and ord > 0:
+            return self._array_context.np.sum(abs(ary)**ord)**(1/ord)
+        else:
+            raise NotImplementedError(f"unsupported value of 'ord': {ord}")
 # }}}
 
 
diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py
index 154aa44..788dcdb 100644
--- a/arraycontext/impl/pyopencl/__init__.py
+++ b/arraycontext/impl/pyopencl/__init__.py
@@ -2,6 +2,7 @@
 .. currentmodule:: arraycontext
 .. autoclass:: PyOpenCLArrayContext
 """
+
 __copyright__ = """
 Copyright (C) 2020-1 University of Illinois Board of Trustees
 """
@@ -27,7 +28,7 @@ THE SOFTWARE.
 """
 
 from warnings import warn
-from typing import Sequence, Union
+from typing import Dict, List, Sequence, Optional, Union, TYPE_CHECKING
 
 import numpy as np
 
@@ -37,6 +38,11 @@ from arraycontext.metadata import FirstAxisIsElementsTag
 from arraycontext.context import ArrayContext
 
 
+if TYPE_CHECKING:
+    import pyopencl
+    import loopy as lp
+
+
 # {{{ PyOpenCLArrayContext
 
 class PyOpenCLArrayContext(ArrayContext):
@@ -62,7 +68,11 @@ class PyOpenCLArrayContext(ArrayContext):
         as the allocator can help avoid this cost.
     """
 
-    def __init__(self, queue, allocator=None, wait_event_queue_length=None):
+    def __init__(self,
+            queue: "pyopencl.CommandQueue",
+            allocator: Optional["pyopencl.tools.AllocatorInterface"] = None,
+            wait_event_queue_length: Optional[int] = None,
+            force_device_scalars: bool = False) -> None:
         r"""
         :arg wait_event_queue_length: The length of a queue of
             :class:`~pyopencl.Event` objects that are maintained by the
@@ -83,19 +93,31 @@ class PyOpenCLArrayContext(ArrayContext):
 
             For now, *wait_event_queue_length* should be regarded as an
             experimental feature that may change or disappear at any minute.
+
+        :arg force_device_scalars: if *True*, scalar results returned from
+            reductions in :attr:`ArrayContext.np` will be kept on the device.
+            If *False*, the equivalent of :meth:`~ArrayContext.freeze` and
+            :meth:`~ArrayContext.to_numpy` is applied to transfer the results
+            to the host.
         """
+        if not force_device_scalars:
+            warn("Returning host scalars from the array context is deprecated. "
+                    "To return device scalars set 'force_device_scalars=True'. "
+                    "Support for returning host scalars will be removed in 2022.",
+                    DeprecationWarning, stacklevel=2)
+
         import pyopencl as cl
 
         super().__init__()
         self.context = queue.context
         self.queue = queue
         self.allocator = allocator if allocator else None
-
         if wait_event_queue_length is None:
             wait_event_queue_length = 10
 
+        self._force_device_scalars = force_device_scalars
         self._wait_event_queue_length = wait_event_queue_length
-        self._kernel_name_to_wait_event_queue = {}
+        self._kernel_name_to_wait_event_queue: Dict[str, List[cl.Event]] = {}
 
         if queue.device.type & cl.device_type.GPU:
             if allocator is None:
@@ -110,7 +132,8 @@ class PyOpenCLArrayContext(ArrayContext):
                         "are running Python in debug mode. Use 'python -O' for "
                         "a noticeable speed improvement.")
 
-        self._loopy_transform_cache = {}
+        self._loopy_transform_cache: \
+                Dict["lp.TranslationUnit", "lp.TranslationUnit"] = {}
 
     def _get_fake_numpy_namespace(self):
         from arraycontext.impl.pyopencl.fake_numpy import PyOpenCLFakeNumpyNamespace
@@ -133,6 +156,9 @@ class PyOpenCLArrayContext(ArrayContext):
         return cl_array.to_device(self.queue, array, allocator=self.allocator)
 
     def to_numpy(self, array):
+        if not self._force_device_scalars and np.isscalar(array):
+            return array
+
         return array.get(queue=self.queue)
 
     def call_loopy(self, t_unit, **kwargs):
@@ -229,7 +255,9 @@ class PyOpenCLArrayContext(ArrayContext):
         return array
 
     def clone(self):
-        return type(self)(self.queue, self.allocator, self._wait_event_queue_length)
+        return type(self)(self.queue, self.allocator,
+                wait_event_queue_length=self._wait_event_queue_length,
+                force_device_scalars=self._force_device_scalars)
 
     @property
     def permits_inplace_modification(self):
diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py
index db3773a..20e0d48 100644
--- a/arraycontext/impl/pyopencl/fake_numpy.py
+++ b/arraycontext/impl/pyopencl/fake_numpy.py
@@ -29,13 +29,10 @@ THE SOFTWARE.
 from functools import partial
 import operator
 
-import numpy as np
-
 from arraycontext.fake_numpy import \
         BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace
 from arraycontext.container.traversal import (rec_multimap_array_container,
                                               rec_map_array_container)
-from arraycontext.container import serialize_container, is_array_container
 
 try:
     import pyopencl as cl  # noqa: F401
@@ -107,14 +104,25 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace):
         return rec_multimap_array_container(where_inner, criterion, then, else_)
 
     def sum(self, a, dtype=None):
-        return cl_array.sum(
-                a, dtype=dtype, queue=self._array_context.queue).get()[()]
+        result = cl_array.sum(a, dtype=dtype, queue=self._array_context.queue)
+        if not self._array_context._force_device_scalars:
+            result = result.get()[()]
+
+        return result
 
     def min(self, a):
-        return cl_array.min(a, queue=self._array_context.queue).get()[()]
+        result = cl_array.min(a, queue=self._array_context.queue)
+        if not self._array_context._force_device_scalars:
+            result = result.get()[()]
+
+        return result
 
     def max(self, a):
-        return cl_array.max(a, queue=self._array_context.queue).get()[()]
+        result = cl_array.max(a, queue=self._array_context.queue)
+        if not self._array_context._force_device_scalars:
+            result = result.get()[()]
+
+        return result
 
     def stack(self, arrays, axis=0):
         return rec_multimap_array_container(
@@ -159,75 +167,8 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace):
 
 # {{{ fake np.linalg
 
-def _flatten_array(ary):
-    assert isinstance(ary, cl_array.Array)
-
-    if ary.size == 0:
-        # Work around https://github.com/inducer/pyopencl/pull/402
-        return ary._new_with_changes(
-                data=None, offset=0, shape=(0,), strides=(ary.dtype.itemsize,))
-    if ary.flags.f_contiguous:
-        return ary.reshape(-1, order="F")
-    elif ary.flags.c_contiguous:
-        return ary.reshape(-1, order="C")
-    else:
-        raise ValueError("cannot flatten array "
-                f"with strides {ary.strides} of {ary.dtype}")
-
-
 class _PyOpenCLFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
-    def norm(self, ary, ord=None):
-        from numbers import Number
-
-        if isinstance(ary, Number):
-            return abs(ary)
-
-        if ord is None and isinstance(ary, cl_array.Array):
-            if ary.ndim == 1:
-                ord = 2
-            else:
-                # mimics numpy's norm computation
-                return self.norm(_flatten_array(ary), ord=2)
-
-        try:
-            from meshmode.dof_array import DOFArray
-        except ImportError:
-            pass
-        else:
-            if isinstance(ary, DOFArray):
-                from warnings import warn
-                warn("Taking an actx.np.linalg.norm of a DOFArray is deprecated. "
-                        "(DOFArrays use 2D arrays internally, and "
-                        "actx.np.linalg.norm should compute matrix norms of those.) "
-                        "This will stop working in 2022. "
-                        "Use meshmode.dof_array.flat_norm instead.",
-                        DeprecationWarning, stacklevel=2)
-
-                import numpy.linalg as la
-                return la.norm(
-                        [self.norm(_flatten_array(subary), ord=ord)
-                            for _, subary in serialize_container(ary)],
-                        ord=ord)
-
-        if is_array_container(ary):
-            import numpy.linalg as la
-            return la.norm(
-                    [self.norm(subary, ord=ord)
-                        for _, subary in serialize_container(ary)],
-                    ord=ord)
-
-        if len(ary.shape) != 1:
-            raise NotImplementedError("only vector norms are implemented")
-
-        if ary.size == 0:
-            return 0
-
-        if ord == np.inf:
-            return self._array_context.np.max(abs(ary))
-        elif isinstance(ord, Number) and ord > 0:
-            return self._array_context.np.sum(abs(ary)**ord)**(1/ord)
-        else:
-            raise NotImplementedError(f"unsupported value of 'ord': {ord}")
+    pass
 
 # }}}
 
diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py
index 3327aa4..acde939 100644
--- a/arraycontext/pytest.py
+++ b/arraycontext/pytest.py
@@ -1,9 +1,12 @@
 """
 .. currentmodule:: arraycontext
+
+.. autoclass:: PytestPyOpenCLArrayContextFactory
+
+.. autofunction:: pytest_generate_tests_for_array_contexts
 .. autofunction:: pytest_generate_tests_for_pyopencl_array_context
 """
 
-
 __copyright__ = """
 Copyright (C) 2020-1 University of Illinois Board of Trustees
 """
@@ -28,10 +31,199 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from typing import Any, Callable, Dict, Sequence, Type, Union
+
+import pyopencl as cl
+from arraycontext.context import ArrayContext
+
+
+# {{{ array context factories
+
+class PytestPyOpenCLArrayContextFactory:
+    """
+    .. automethod:: __init__
+    .. automethod:: __call__
+    """
+
+    def __init__(self, device):
+        """
+        :arg device: a :class:`pyopencl.Device`.
+        """
+        self.device = device
+
+    def get_command_queue(self):
+        # Get rid of leftovers from past tests.
+        # CL implementations are surprisingly limited in how many
+        # simultaneous contexts they allow...
+        from pyopencl.tools import clear_first_arg_caches
+        clear_first_arg_caches()
+
+        from gc import collect
+        collect()
+
+        ctx = cl.Context([self.device])
+        return cl.CommandQueue(ctx)
+
+    def __call__(self) -> ArrayContext:
+        raise NotImplementedError
+
+
+class _PyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory):
+    force_device_scalars = True
+
+    def __call__(self):
+        from arraycontext import PyOpenCLArrayContext
+        return PyOpenCLArrayContext(
+                self.get_command_queue(),
+                force_device_scalars=self.force_device_scalars)
+
+    def __str__(self):
+        return ("<PyOpenCLArrayContext for <pyopencl.Device '%s' on '%s'>" %
+                (self.device.name.strip(),
+                 self.device.platform.name.strip()))
+
+
+class _DeprecatedPyOpenCLArrayContextFactory(_PyOpenCLArrayContextFactory):
+    force_device_scalars = False
+
+
+_ARRAY_CONTEXT_FACTORY_REGISTRY: \
+        Dict[str, Type[PytestPyOpenCLArrayContextFactory]] = {
+                "pyopencl": _PyOpenCLArrayContextFactory,
+                "pyopencl-deprecated": _DeprecatedPyOpenCLArrayContextFactory,
+                }
+
+
+def register_array_context_factory(
+        name: str,
+        factory: Type[PytestPyOpenCLArrayContextFactory]) -> None:
+    if name in _ARRAY_CONTEXT_FACTORY_REGISTRY:
+        raise ValueError(f"factory '{name}' already exists")
+
+    _ARRAY_CONTEXT_FACTORY_REGISTRY[name] = factory
+
+# }}}
+
 
 # {{{ pytest integration
 
-def pytest_generate_tests_for_pyopencl_array_context(metafunc):
+def pytest_generate_tests_for_array_contexts(
+        factories: Sequence[Union[str, Type[PytestPyOpenCLArrayContextFactory]]], *,
+        factory_arg_name: str = "actx_factory",
+        ) -> Callable[[Any], None]:
+    """Parametrize tests for pytest to use an :class:`~arraycontext.ArrayContext`.
+
+    Using this function in :mod:`pytest` test scripts allows you to use the
+    argument *factory_arg_name*, which is a callable that returns a
+    :class:`~arraycontext.ArrayContext`. All test functions will automatically
+    be run once for each implemented array context. To select specific array
+    context implementations explicitly define, for example,
+
+    .. code-block:: python
+
+        pytest_generate_tests = pytest_generate_tests_for_array_context([
+            "pyopencl",
+            ])
+
+    to use the :mod:`pyopencl`-based array context. For :mod:`pyopencl`-based
+    contexts :func:`pyopencl.tools.pytest_generate_tests_for_pyopencl` is used
+    as a backend, which allows specifying the ``PYOPENCL_TEST`` environment
+    variable for device selection.
+
+    The environment variable ``ARRAYCONTEXT_TEST`` can also be used to
+    overwrite any chosen implementations through *factories*. This is a
+    comma-separated list of known array contexts.
+
+    Current supported implementations include:
+
+    * ``"pyopencl"``, which creates a :class:`~arraycontext.PyOpenCLArrayContext`
+      with ``force_device_scalars=True``.
+    * ``"pyopencl-deprecated"``, which creates a
+      :class:`~arraycontext.PyOpenCLArrayContext` with
+      ``force_device_scalars=False``.
+
+    :arg factories: a list of identifiers or
+        :class:`PytestPyOpenCLArrayContextFactory` classes (not instances)
+        for which to generate test fixtures.
+    """
+
+    # {{{ get all requested array context factories
+
+    import os
+    env_factory_string = os.environ.get("ARRAYCONTEXT_TEST", None)
+
+    if env_factory_string is not None:
+        unique_factories = set(env_factory_string.split(","))
+    else:
+        unique_factories = set(factories)               # type: ignore[arg-type]
+
+    if not unique_factories:
+        raise ValueError("no array context factories were selected")
+
+    unknown_factories = [
+            factory for factory in unique_factories
+            if (isinstance(factory, str)
+                and factory not in _ARRAY_CONTEXT_FACTORY_REGISTRY)
+            ]
+
+    if unknown_factories:
+        if env_factory_string is not None:
+            raise RuntimeError(
+                    "unknown array context factories passed through environment "
+                    f"variable 'ARRAYCONTEXT_TEST': {unknown_factories}")
+        else:
+            raise ValueError(f"unknown array contexts: {unknown_factories}")
+
+    unique_factories = set([
+        _ARRAY_CONTEXT_FACTORY_REGISTRY.get(factory, factory)  # type: ignore[misc]
+        for factory in unique_factories])
+
+    # }}}
+
+    def inner(metafunc):
+        # {{{ get pyopencl devices
+
+        import pyopencl.tools as cl_tools
+        arg_names = cl_tools.get_pyopencl_fixture_arg_names(
+                metafunc, extra_arg_names=[factory_arg_name])
+
+        if not arg_names:
+            return
+
+        arg_values, ids = cl_tools.get_pyopencl_fixture_arg_values()
+
+        # }}}
+
+        # {{{ add array context factory to arguments
+
+        if factory_arg_name in arg_names:
+            if "ctx_factory" in arg_names or "ctx_getter" in arg_names:
+                raise RuntimeError(
+                        f"Cannot use both an '{factory_arg_name}' and a "
+                        "'ctx_factory' / 'ctx_getter' as arguments.")
+
+            arg_values_with_actx = []
+            for arg_dict in arg_values:
+                arg_values_with_actx.extend([
+                    {factory_arg_name: factory(arg_dict["device"]), **arg_dict}
+                    for factory in unique_factories
+                    ])
+        else:
+            arg_values_with_actx = arg_values
+
+        arg_value_tuples = [
+                tuple(arg_dict[name] for name in arg_names)
+                for arg_dict in arg_values_with_actx
+                ]
+
+        # }}}
+
+        metafunc.parametrize(arg_names, arg_value_tuples, ids=ids)
+
+    return inner
+
+
+def pytest_generate_tests_for_pyopencl_array_context(metafunc) -> None:
     """Parametrize tests for pytest to use a
     :class:`~arraycontext.PyOpenCLArrayContext`.
 
@@ -42,8 +234,9 @@ def pytest_generate_tests_for_pyopencl_array_context(metafunc):
 
     .. code-block:: python
 
-       from arraycontext import pytest_generate_tests_for_pyopencl
-            as pytest_generate_tests
+       from arraycontext import (
+            pytest_generate_tests_for_pyopencl_array_context
+            as pytest_generate_tests)
 
     in your pytest test scripts allows you to use the argument ``actx_factory``,
     in your test functions, and they will automatically be
@@ -55,42 +248,9 @@ def pytest_generate_tests_for_pyopencl_array_context(metafunc):
     for device selection.
     """
 
-    import pyopencl as cl
-    from pyopencl.tools import _ContextFactory
-
-    class ArrayContextFactory(_ContextFactory):
-        def __call__(self):
-            ctx = super().__call__()
-            from arraycontext.impl.pyopencl import PyOpenCLArrayContext
-            return PyOpenCLArrayContext(cl.CommandQueue(ctx))
-
-        def __str__(self):
-            return ("<array context factory for <pyopencl.Device '%s' on '%s'>" %
-                    (self.device.name.strip(),
-                     self.device.platform.name.strip()))
-
-    import pyopencl.tools as cl_tools
-    arg_names = cl_tools.get_pyopencl_fixture_arg_names(
-            metafunc, extra_arg_names=["actx_factory"])
-
-    if not arg_names:
-        return
-
-    arg_values, ids = cl_tools.get_pyopencl_fixture_arg_values()
-    if "actx_factory" in arg_names:
-        if "ctx_factory" in arg_names or "ctx_getter" in arg_names:
-            raise RuntimeError("Cannot use both an 'actx_factory' and a "
-                    "'ctx_factory' / 'ctx_getter' as arguments.")
-
-        for arg_dict in arg_values:
-            arg_dict["actx_factory"] = ArrayContextFactory(arg_dict["device"])
-
-    arg_values = [
-            tuple(arg_dict[name] for name in arg_names)
-            for arg_dict in arg_values
-            ]
-
-    metafunc.parametrize(arg_names, arg_values, ids=ids)
+    pytest_generate_tests_for_array_contexts([
+        "pyopencl-deprecated",
+        ], factory_arg_name="actx_factory")(metafunc)
 
 # }}}
 
diff --git a/run-mypy.sh b/run-mypy.sh
index 75f2cff..52241ae 100755
--- a/run-mypy.sh
+++ b/run-mypy.sh
@@ -1,3 +1,3 @@
 #!/bin/bash
 
-python -m mypy arraycontext/ examples/ test/
+python -m mypy --show-error-codes arraycontext examples test
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 40a5c60..1b82254 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -33,14 +33,18 @@ from arraycontext import (
         freeze, thaw,
         FirstAxisIsElementsTag)
 from arraycontext import (  # noqa: F401
-        pytest_generate_tests_for_pyopencl_array_context
-        as pytest_generate_tests,
+        pytest_generate_tests_for_array_contexts,
         _acf)
 
 import logging
 logger = logging.getLogger(__name__)
 
 
+pytest_generate_tests = pytest_generate_tests_for_array_contexts([
+    "pyopencl", "pyopencl-deprecated",
+    ])
+
+
 # {{{ stand-in DOFArray implementation
 
 @with_container_arithmetic(
@@ -398,8 +402,12 @@ def test_dof_array_reductions_same_as_numpy(actx_factory):
         np_red = getattr(np, name)(ary)
         actx_red = getattr(actx.np, name)(actx.from_numpy(ary))
 
-        assert isinstance(actx_red, Number)
-        assert np.allclose(np_red, actx_red)
+        if actx._force_device_scalars:
+            assert actx_red.shape == ()
+        else:
+            assert isinstance(actx_red, Number)
+
+        assert np.allclose(np_red, actx.to_numpy(actx_red))
 
 # }}}
 
@@ -715,10 +723,10 @@ def test_norm_complex(actx_factory, norm_ord):
 
 @pytest.mark.parametrize("ndim", [1, 2, 3, 4, 5])
 def test_norm_ord_none(actx_factory, ndim):
-    from numpy.random import default_rng
-
     actx = actx_factory()
 
+    from numpy.random import default_rng
+
     rng = default_rng()
     shape = tuple(rng.integers(2, 7, ndim))
 
@@ -727,7 +735,7 @@ def test_norm_ord_none(actx_factory, ndim):
     norm_a_ref = np.linalg.norm(a, ord=None)
     norm_a = actx.np.linalg.norm(actx.from_numpy(a), ord=None)
 
-    np.testing.assert_allclose(norm_a, norm_a_ref)
+    np.testing.assert_allclose(actx.to_numpy(norm_a), norm_a_ref)
 
 
 if __name__ == "__main__":
-- 
GitLab