From 46a5cfe61cee7fe972cd375e05277258ca7274d5 Mon Sep 17 00:00:00 2001
From: Matt Smith <mjsmith6@illinois.edu>
Date: Wed, 4 Sep 2024 16:04:33 -0500
Subject: [PATCH] Forbid setting `force_device_scalars=False` (#278)

* forbid force_device_scalars=False

* attempt to improve compatibility
---
 arraycontext/__init__.py                 |  2 -
 arraycontext/impl/pyopencl/__init__.py   | 41 +++++++-------
 arraycontext/impl/pyopencl/fake_numpy.py | 42 +++-----------
 arraycontext/pytest.py                   | 72 +++++-------------------
 test/test_arraycontext.py                |  7 +--
 5 files changed, 43 insertions(+), 121 deletions(-)

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index 4e0ba83..c40117e 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -90,7 +90,6 @@ from .pytest import (
     PytestArrayContextFactory,
     PytestPyOpenCLArrayContextFactory,
     pytest_generate_tests_for_array_contexts,
-    pytest_generate_tests_for_pyopencl_array_context,
 )
 from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag
 
@@ -139,7 +138,6 @@ __all__ = (
     "multimapped_over_array_containers",
     "outer",
     "pytest_generate_tests_for_array_contexts",
-    "pytest_generate_tests_for_pyopencl_array_context",
     "rec_map_array_container",
     "rec_map_reduce_array_container",
     "rec_multimap_array_container",
diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py
index de188cb..83dc9c0 100644
--- a/arraycontext/impl/pyopencl/__init__.py
+++ b/arraycontext/impl/pyopencl/__init__.py
@@ -84,7 +84,7 @@ class PyOpenCLArrayContext(ArrayContext):
             queue: pyopencl.CommandQueue,
             allocator: Optional[pyopencl.tools.AllocatorBase] = None,
             wait_event_queue_length: Optional[int] = None,
-            force_device_scalars: bool = False) -> None:
+            force_device_scalars: Optional[bool] = None) -> None:
         r"""
         :arg wait_event_queue_length: The length of a queue of
             :class:`~pyopencl.Event` objects that are maintained by the
@@ -105,21 +105,15 @@ class PyOpenCLArrayContext(ArrayContext):
 
             For now, *wait_event_queue_length* should be regarded as an
             experimental feature that may change or disappear at any minute.
-
-        :arg force_device_scalars: if *True*, scalar results returned from
-            reductions in :attr:`ArrayContext.np` will be kept on the device.
-            If *False*, the equivalent of :meth:`~ArrayContext.freeze` and
-            :meth:`~ArrayContext.to_numpy` is applied to transfer the results
-            to the host.
         """
-        if not force_device_scalars:
-            warn("Configuring the PyOpenCLArrayContext to return host scalars "
-                    "from reductions is deprecated. "
-                    "To configure the PyOpenCLArrayContext to return "
-                    "device scalars, pass 'force_device_scalars=True' to the "
-                    "constructor. "
-                    "Support for returning host scalars will be removed in 2022.",
-                    DeprecationWarning, stacklevel=2)
+        if force_device_scalars is not None:
+            warn(
+                "`force_device_scalars` is deprecated and will be removed in 2025.",
+                DeprecationWarning, stacklevel=2)
+
+            if not force_device_scalars:
+                raise ValueError(
+                    "Passing force_device_scalars=False is not allowed.")
 
         import pyopencl as cl
         import pyopencl.array as cl_array
@@ -131,7 +125,12 @@ class PyOpenCLArrayContext(ArrayContext):
         if wait_event_queue_length is None:
             wait_event_queue_length = 10
 
-        self._force_device_scalars = force_device_scalars
+        self._force_device_scalars = True
+        # Subclasses might still be using the old
+        # "force_devices_scalars: bool = False" interface, in which case we need
+        # to explicitly pass force_device_scalars=True in clone()
+        self._passed_force_device_scalars = force_device_scalars is not None
+
         self._wait_event_queue_length = wait_event_queue_length
         self._kernel_name_to_wait_event_queue: Dict[str, List[cl.Event]] = {}
 
@@ -267,9 +266,13 @@ class PyOpenCLArrayContext(ArrayContext):
         return {name: tga.to_tagged_cl_array(ary) for name, ary in result.items()}
 
     def clone(self):
-        return type(self)(self.queue, self.allocator,
-                wait_event_queue_length=self._wait_event_queue_length,
-                force_device_scalars=self._force_device_scalars)
+        if self._passed_force_device_scalars:
+            return type(self)(self.queue, self.allocator,
+                    wait_event_queue_length=self._wait_event_queue_length,
+                    force_device_scalars=True)
+        else:
+            return type(self)(self.queue, self.allocator,
+                    wait_event_queue_length=self._wait_event_queue_length)
 
     # }}}
 
diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py
index 848870a..ac79245 100644
--- a/arraycontext/impl/pyopencl/fake_numpy.py
+++ b/arraycontext/impl/pyopencl/fake_numpy.py
@@ -169,15 +169,11 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
     # {{{ linear algebra
 
     def vdot(self, x, y, dtype=None):
-        result = rec_multimap_reduce_array_container(
+        return rec_multimap_reduce_array_container(
                 sum,
                 partial(cl_array.vdot, dtype=dtype, queue=self._array_context.queue),
                 x, y)
 
-        if not self._array_context._force_device_scalars:
-            result = result.get()[()]
-        return result
-
     # }}}
 
     # {{{ logic functions
@@ -190,15 +186,11 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
                 return np.int8(all([ary]))
             return ary.all(queue=queue)
 
-        result = rec_map_reduce_array_container(
+        return rec_map_reduce_array_container(
                 partial(reduce, partial(cl_array.minimum, queue=queue)),
                 _all,
                 a)
 
-        if not self._array_context._force_device_scalars:
-            result = result.get()[()]
-        return result
-
     def any(self, a):
         queue = self._array_context.queue
 
@@ -207,15 +199,11 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
                 return np.int8(any([ary]))
             return ary.any(queue=queue)
 
-        result = rec_map_reduce_array_container(
+        return rec_map_reduce_array_container(
                 partial(reduce, partial(cl_array.maximum, queue=queue)),
                 _any,
                 a)
 
-        if not self._array_context._force_device_scalars:
-            result = result.get()[()]
-        return result
-
     def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
         actx = self._array_context
         queue = actx.queue
@@ -251,11 +239,7 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
                             in zip(serialized_x, serialized_y)],
                         true_ary)
 
-        result = rec_equal(a, b)
-        if not self._array_context._force_device_scalars:
-            result = result.get()[()]
-
-        return result
+        return rec_equal(a, b)
 
     # FIXME: This should be documentation, not a comment.
     # These are here mainly because some arrays may choose to interpret
@@ -305,11 +289,7 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
 
             return cl_array.sum(ary, dtype=dtype, queue=self._array_context.queue)
 
-        result = rec_map_reduce_array_container(sum, _rec_sum, a)
-
-        if not self._array_context._force_device_scalars:
-            result = result.get()[()]
-        return result
+        return rec_map_reduce_array_container(sum, _rec_sum, a)
 
     def maximum(self, x, y):
         return rec_multimap_array_container(
@@ -327,15 +307,11 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
                 raise NotImplementedError(f"Max. over '{axis}' axes not supported.")
             return cl_array.max(ary, queue=queue)
 
-        result = rec_map_reduce_array_container(
+        return rec_map_reduce_array_container(
                 partial(reduce, partial(cl_array.maximum, queue=queue)),
                 _rec_max,
                 a)
 
-        if not self._array_context._force_device_scalars:
-            result = result.get()[()]
-        return result
-
     max = amax
 
     def minimum(self, x, y):
@@ -354,15 +330,11 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
                 raise NotImplementedError(f"Min. over '{axis}' axes not supported.")
             return cl_array.min(ary, queue=queue)
 
-        result = rec_map_reduce_array_container(
+        return rec_map_reduce_array_container(
                 partial(reduce, partial(cl_array.minimum, queue=queue)),
                 _rec_min,
                 a)
 
-        if not self._array_context._force_device_scalars:
-            result = result.get()[()]
-        return result
-
     min = amin
 
     def absolute(self, a):
diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py
index 088c7e3..c778154 100644
--- a/arraycontext/pytest.py
+++ b/arraycontext/pytest.py
@@ -5,7 +5,6 @@
 .. autoclass:: PytestPyOpenCLArrayContextFactory
 
 .. autofunction:: pytest_generate_tests_for_array_contexts
-.. autofunction:: pytest_generate_tests_for_pyopencl_array_context
 """
 
 __copyright__ = """
@@ -88,7 +87,16 @@ class PytestPyOpenCLArrayContextFactory(PytestArrayContextFactory):
 
 
 class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFactory):
-    force_device_scalars = True
+    # Deprecated, remove in 2025.
+    _force_device_scalars = True
+
+    @property
+    def force_device_scalars(self):
+        from warnings import warn
+        warn(
+            "force_device_scalars is deprecated and will be removed in 2025.",
+             DeprecationWarning, stacklevel=2)
+        return self._force_device_scalars
 
     @property
     def actx_class(self):
@@ -117,8 +125,7 @@ class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFact
 
         return self.actx_class(
                 queue,
-                allocator=alloc,
-                force_device_scalars=self.force_device_scalars)
+                allocator=alloc)
 
     def __str__(self):
         return (f"<{self.actx_class.__name__} "
@@ -126,11 +133,6 @@ class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFact
             f"on '{self.device.platform.name.strip()}'>>")
 
 
-class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars(
-        _PytestPyOpenCLArrayContextFactoryWithClass):
-    force_device_scalars = False
-
-
 class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory):
     @classmethod
     def is_available(cls) -> bool:
@@ -245,8 +247,6 @@ class _PytestNumpyArrayContextFactory(PytestArrayContextFactory):
 _ARRAY_CONTEXT_FACTORY_REGISTRY: \
         Dict[str, Type[PytestArrayContextFactory]] = {
                 "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass,
-                "pyopencl-deprecated":
-                _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars,
                 "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory,
                 "pytato:jax": _PytestPytatoJaxArrayContextFactory,
                 "eagerjax": _PytestEagerJaxArrayContextFactory,
@@ -285,10 +285,7 @@ def pytest_generate_tests_for_array_contexts(
             "pyopencl",
             ])
 
-    to use the :mod:`pyopencl`-based array context. For :mod:`pyopencl`-based
-    contexts :func:`pyopencl.tools.pytest_generate_tests_for_pyopencl` is used
-    as a backend, which allows specifying the ``PYOPENCL_TEST`` environment
-    variable for device selection.
+    to use the :mod:`pyopencl`-based array context.
 
     The environment variable ``ARRAYCONTEXT_TEST`` can also be used to
     overwrite any chosen implementations through *factories*. This is a
@@ -296,11 +293,7 @@ def pytest_generate_tests_for_array_contexts(
 
     Current supported implementations include:
 
-    * ``"pyopencl"``, which creates a :class:`~arraycontext.PyOpenCLArrayContext`
-      with ``force_device_scalars=True``.
-    * ``"pyopencl-deprecated"``, which creates a
-      :class:`~arraycontext.PyOpenCLArrayContext` with
-      ``force_device_scalars=False``.
+    * ``"pyopencl"``, which creates a :class:`~arraycontext.PyOpenCLArrayContext`.
     * ``"pytato-pyopencl"``, which creates a
       :class:`~arraycontext.PytatoPyOpenCLArrayContext`.
 
@@ -404,45 +397,6 @@ def pytest_generate_tests_for_array_contexts(
 
     return inner
 
-
-def pytest_generate_tests_for_pyopencl_array_context(metafunc) -> None:
-    """Parametrize tests for pytest to use a
-    :class:`~arraycontext.PyOpenCLArrayContext`.
-
-    Performs device enumeration analogously to
-    :func:`pyopencl.tools.pytest_generate_tests_for_pyopencl`.
-
-    Using the line:
-
-    .. code-block:: python
-
-       from arraycontext import (
-            pytest_generate_tests_for_pyopencl_array_context
-            as pytest_generate_tests)
-
-    in your pytest test scripts allows you to use the argument ``actx_factory``,
-    in your test functions, and they will automatically be
-    run once for each OpenCL device/platform in the system, as appropriate,
-    with an argument-less function that returns an
-    :class:`~arraycontext.ArrayContext` when called.
-
-    It also allows you to specify the ``PYOPENCL_TEST`` environment variable
-    for device selection.
-    """
-
-    from warnings import warn
-    warn("pytest_generate_tests_for_pyopencl_array_context is deprecated. "
-            "Use 'pytest_generate_tests = "
-            "arraycontext.pytest_generate_tests_for_array_contexts"
-            "([\"pyopencl-deprecated\"])' instead. "
-            "pytest_generate_tests_for_pyopencl_array_context will stop working "
-            "in 2022.",
-            DeprecationWarning, stacklevel=2)
-
-    pytest_generate_tests_for_array_contexts([
-        "pyopencl-deprecated",
-        ], factory_arg_name="actx_factory")(metafunc)
-
 # }}}
 
 
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 47d8390..7bea0dc 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -665,12 +665,7 @@ def test_reductions_same_as_numpy(actx_factory, op):
     actx_red = getattr(actx.np, op)(actx.from_numpy(ary))
     actx_red = actx.to_numpy(actx_red)
 
-    from numbers import Number
-
-    if isinstance(actx, PyOpenCLArrayContext) and (not actx._force_device_scalars):
-        assert isinstance(actx_red, Number)
-    else:
-        assert actx_red.shape == ()
+    assert actx_red.shape == ()
 
     assert np.allclose(np_red, actx_red)
 
-- 
GitLab