From 97940584273f457d89dcab0dcc0b3cf376d8dc68 Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Fri, 25 Jun 2021 17:18:55 -0500
Subject: [PATCH] small fixes

---
 arraycontext/fake_numpy.py           |  1 -
 arraycontext/impl/pytato/__init__.py |  1 +
 arraycontext/loopy.py                |  3 +--
 arraycontext/pytest.py               | 14 ++++++++++++++
 test/test_arraycontext.py            |  2 +-
 test/test_utils.py                   |  7 +++++--
 6 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py
index cb1b3fc..1f208f3 100644
--- a/arraycontext/fake_numpy.py
+++ b/arraycontext/fake_numpy.py
@@ -236,7 +236,6 @@ class BaseFakeNumpyLinalgNamespace:
             return self._array_context.np.sum(abs(ary)**ord)**(1/ord)
         else:
             raise NotImplementedError(f"unsupported value of 'ord': {ord}")
-
 # }}}
 
 
diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index 20e208f..c3fecf8 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -54,6 +54,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
 
     def __init__(self, queue, allocator=None):
         super().__init__()
+        self._force_device_scalars = True
         self.queue = queue
         self.allocator = allocator
         self.np = self._get_fake_numpy_namespace()
diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py
index 2a086b7..8f2816d 100644
--- a/arraycontext/loopy.py
+++ b/arraycontext/loopy.py
@@ -51,8 +51,7 @@ def make_loopy_program(domains, statements, kernel_data=None,
             statements,
             kernel_data=kernel_data,
             options=_DEFAULT_LOOPY_OPTIONS,
-            # FIXME: Restore when https://github.com/inducer/loopy/pull/431 is merged
-            # default_offset=lp.auto,
+            default_offset=lp.auto,
             name=name,
             lang_version=MOST_RECENT_LANGUAGE_VERSION)
 
diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py
index a939b05..9e938b4 100644
--- a/arraycontext/pytest.py
+++ b/arraycontext/pytest.py
@@ -87,10 +87,24 @@ class _DeprecatedPyOpenCLArrayContextFactory(_PyOpenCLArrayContextFactory):
     force_device_scalars = False
 
 
+class _PytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory):
+    force_device_scalars = False
+
+    def __call__(self):
+        from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext
+        return PytatoPyOpenCLArrayContext(self.get_command_queue())
+
+    def __str__(self):
+        return ("<Pytato array context factory for <pyopencl.Device '%s' on '%s'>"
+                % (self.device.name.strip(),
+                 self.device.platform.name.strip()))
+
+
 _ARRAY_CONTEXT_FACTORY_REGISTRY: \
         Dict[str, Type[PytestPyOpenCLArrayContextFactory]] = {
                 "pyopencl": _PyOpenCLArrayContextFactory,
                 "pyopencl-deprecated": _DeprecatedPyOpenCLArrayContextFactory,
+                "pytato-pyopencl": _PytatoPyOpenCLArrayContextFactory,
                 }
 
 
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 3e9ec2c..eb55fb4 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
 
 
 pytest_generate_tests = pytest_generate_tests_for_array_contexts([
-    "pyopencl", "pyopencl-deprecated",
+    "pyopencl", "pyopencl-deprecated", "pytato-pyopencl"
     ])
 
 
diff --git a/test/test_utils.py b/test/test_utils.py
index 63cadc7..e40dda1 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -24,13 +24,16 @@ THE SOFTWARE.
 """
 
 from arraycontext import (  # noqa: F401
-        pytest_generate_tests_for_array_contexts
-        as pytest_generate_tests,
+        pytest_generate_tests_for_array_contexts,
         _acf)
 
 import logging
 logger = logging.getLogger(__name__)
 
+pytest_generate_tests = pytest_generate_tests_for_array_contexts([
+    "pyopencl", "pyopencl-deprecated", "pytato-pyopencl"
+    ])
+
 
 def test_pt_actx_key_stringification_uniqueness():
     from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
-- 
GitLab