From 70f251f20af68f3828e9e6953a3a596434fecff6 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Sun, 26 Jun 2022 09:03:21 +0300
Subject: [PATCH] pytato: add more array container support

---
 arraycontext/impl/pytato/__init__.py | 479 ++++++++++++++++-----------
 arraycontext/impl/pytato/compile.py  |   2 +-
 2 files changed, 295 insertions(+), 186 deletions(-)

diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index ed3bff8..8d7e042 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -42,17 +42,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+import abc
 import sys
-from arraycontext.context import ArrayContext, ScalarLike
-from arraycontext.container.traversal import (rec_map_array_container,
-                                              with_array_context)
-from arraycontext.metadata import NameHint
+from typing import (Any, Callable, Union, Tuple, Type, FrozenSet, Dict, Optional,
+                    TYPE_CHECKING)
 
 import numpy as np
-from typing import (Any, Callable, Union, TYPE_CHECKING, Tuple, Type, FrozenSet,
-        Dict, Optional)
 from pytools.tag import ToTagSetConvertible, normalize_tags, Tag
-import abc
+
+from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike
+from arraycontext.container.traversal import (rec_map_array_container,
+                                              with_array_context)
+from arraycontext.metadata import NameHint
 
 if TYPE_CHECKING:
     import pytato
@@ -105,9 +106,11 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC):
 
     .. automethod:: compile
     """
-    def __init__(self,
-            *, compile_trace_callback: Optional[Callable[[Any, str, Any], None]]
-             = None) -> None:
+
+    def __init__(
+            self, *,
+            compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None
+            ) -> None:
         """
         :arg compile_trace_callback: A function of three arguments
             *(what, stage, ir)*, where *what* identifies the object
@@ -116,9 +119,10 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC):
             representation. This interface should be considered
             unstable.
         """
+        super().__init__()
+
         import pytato as pt
         import loopy as lp
-        super().__init__()
         self._freeze_prg_cache: Dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {}
         self._dag_transform_cache: Dict[
                 pt.DictOfNamedArrays,
@@ -136,13 +140,30 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC):
         from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace
         return PytatoFakeNumpyNamespace(self)
 
+    @abc.abstractproperty
+    def _frozen_array_types(self) -> Tuple[Type, ...]:
+        """
+        Returns valid frozen array types for the array context.
+        """
+
+    # {{{ ArrayContext interface
+
     def empty(self, shape, dtype):
-        raise ValueError(f"{type(self).__name__} does not support empty")
+        raise NotImplementedError(
+            f"{type(self).__name__}.empty is not supported")
 
     def zeros(self, shape, dtype):
         import pytato as pt
         return pt.zeros(shape, dtype)
 
+    def empty_like(self, ary):
+        raise NotImplementedError(
+            f"{type(self).__name__}.empty_like is not supported")
+
+    # }}}
+
+    # {{{ compilation
+
     def transform_dag(self, dag: "pytato.DictOfNamedArrays"
                       ) -> "pytato.DictOfNamedArrays":
         """
@@ -158,21 +179,18 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC):
         return dag
 
     def transform_loopy_program(self, t_unit):
-        raise ValueError(f"{type(self)} does not implement "
-                         "transform_loopy_program. Sub-classes are supposed "
-                         "to implement it.")
-
-    @abc.abstractproperty
-    def frozen_array_types(self) -> Tuple[Type, ...]:
-        """
-        Returns valid frozen array types for the array context.
-        """
-        pass
+        raise ValueError(
+            f"{type(self).__name__} does not implement transform_loopy_program. "
+            "Sub-classes are supposed to implement it.")
 
     @abc.abstractmethod
     def einsum(self, spec, *args, arg_names=None, tagged=()):
         pass
 
+    # }}}
+
+    # {{{ properties
+
     @property
     def permits_inplace_modification(self):
         return False
@@ -185,6 +203,8 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC):
     def permits_advanced_indexing(self):
         return True
 
+    # }}}
+
 # }}}
 
 
@@ -210,10 +230,10 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
 
     .. automethod:: compile
     """
-    def __init__(self, queue: "cl.CommandQueue", allocator=None,
-            *,
-            compile_trace_callback: Optional[Callable[[Any, str, Any], None]]
-             = None) -> None:
+    def __init__(
+            self, queue: "cl.CommandQueue", allocator=None, *,
+            compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None
+            ) -> None:
         """
         :arg compile_trace_callback: A function of three arguments
             *(what, stage, ir)*, where *what* identifies the object
@@ -232,63 +252,82 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
         # unused, but necessary to keep the context alive
         self.context = self.queue.context
 
-    # {{{ ArrayContext interface
-
-    def clone(self):
-        return type(self)(self.queue, self.allocator)
+    @property
+    def _frozen_array_types(self) -> Tuple[Type, ...]:
+        import pyopencl.array as cla
+        return (cla.Array,)
 
-    def from_numpy(self, array: Union[np.ndarray, ScalarLike]):
+    def _rec_map_container(
+            self, func: Callable[[Array], Array], array: ArrayOrContainer,
+            allowed_types: Optional[Tuple[type, ...]] = None, *,
+            default_scalar: Optional[ScalarLike] = None,
+            strict: bool = False) -> ArrayOrContainer:
         import pytato as pt
-        import pyopencl.array as cla
-        cl_array = cla.to_device(self.queue, array)
-        return pt.make_data_wrapper(cl_array)
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
 
-    def to_numpy(self, array):
-        if np.isscalar(array):
-            return array
+        if allowed_types is None:
+            allowed_types = (pt.Array, tga.TaggableCLArray)
 
-        cl_array = self.freeze(array)
-        return cl_array.get(queue=self.queue)
+        def _wrapper(ary):
+            if isinstance(ary, allowed_types):
+                return func(ary)
+            elif not strict and isinstance(ary, self._frozen_array_types):
+                from warnings import warn
+                warn(f"Invoking {type(self).__name__}.{func.__name__[1:]} with"
+                    f" {type(ary).__name__} will be unsupported in 2023. Use"
+                    " 'to_tagged_cl_array' to convert instances to"
+                    " TaggableCLArray.", DeprecationWarning, stacklevel=2)
+
+                return func(tga.to_tagged_cl_array(ary))
+            elif np.isscalar(ary):
+                if default_scalar is None:
+                    return ary
+                else:
+                    return np.array(ary).dtype.type(default_scalar)
+            else:
+                raise TypeError(
+                    f"{type(self).__name__}.{func.__name__[1:]} invoked with "
+                    f"an unsupported array type: got '{type(ary).__name__}', "
+                    f"but expected one of {allowed_types}")
 
-    @property
-    def frozen_array_types(self) -> Tuple[Type, ...]:
-        import pyopencl.array as cla
-        return (cla.Array, )
+        return rec_map_array_container(_wrapper, array)
 
-    def call_loopy(self, program, **kwargs):
-        import pytato as pt
-        from pytato.scalar_expr import SCALAR_CLASSES
-        from pytato.loopy import call_loopy
-        from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
+    # {{{ ArrayContext interface
 
-        entrypoint = program.default_entrypoint.name
+    def zeros_like(self, ary):
+        def _zeros_like(array):
+            return self.zeros(array.shape, array.dtype)
 
-        # {{{ preprocess args
+        return self._rec_map_container(_zeros_like, ary, default_scalar=0)
 
-        processed_kwargs = {}
+    def from_numpy(self, array):
+        import pytato as pt
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
 
-        for kw, arg in sorted(kwargs.items()):
-            if isinstance(arg, (pt.Array,) + SCALAR_CLASSES):
-                pass
-            elif isinstance(arg, TaggableCLArray):
-                arg = self.thaw(arg)
-            else:
-                raise ValueError(f"call_loopy argument '{kw}' expected to be an"
-                                 " instance of 'pytato.Array', 'Number' or"
-                                 f"'TaggableCLArray', got '{type(arg)}'")
+        def _from_numpy(ary):
+            return pt.make_data_wrapper(
+                tga.to_device(self.queue, ary, allocator=self.allocator)
+                )
 
-            processed_kwargs[kw] = arg
+        return with_array_context(
+            self._rec_map_container(_from_numpy, array, (np.ndarray,), strict=True),
+            actx=self)
 
-        # }}}
+    def to_numpy(self, array):
+        def _to_numpy(ary):
+            return ary.get(queue=self.queue)
 
-        return call_loopy(program, processed_kwargs, entrypoint)
+        return with_array_context(
+            self._rec_map_container(_to_numpy, self.freeze(array)),
+            actx=None)
 
     def freeze(self, array):
+        if np.isscalar(array):
+            return array
+
         import pytato as pt
         import pyopencl.array as cla
-        import loopy as lp
 
-        from arraycontext.context import ArrayT
         from arraycontext.container.traversal import rec_keyed_map_array_container
         from arraycontext.impl.pytato.utils import (_normalize_pt_expr,
                                                     get_cl_axes_from_pt_axes)
@@ -296,16 +335,15 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
                                                                   TaggableCLArray)
         from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
 
-        array_as_dict: Dict[str, Union[cla.Array, TaggableCLArray,
-                                       pt.Array]] = {}
+        array_as_dict: Dict[str, Union[cla.Array, TaggableCLArray, pt.Array]] = {}
         key_to_frozen_subary: Dict[str, TaggableCLArray] = {}
         key_to_pt_arrays: Dict[str, pt.Array] = {}
 
-        def _record_leaf_ary_in_dict(key: Tuple[Any, ...],
-                                     ary: ArrayT):
+        def _record_leaf_ary_in_dict(
+                key: Tuple[Any, ...],
+                ary: Union[cla.Array, TaggableCLArray, pt.Array]) -> None:
             key_str = "_ary" + _ary_container_key_stringifier(key)
             array_as_dict[key_str] = ary
-            return ary
 
         rec_keyed_map_array_container(_record_leaf_ary_in_dict, array)
 
@@ -314,37 +352,37 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
         for key, subary in array_as_dict.items():
             if isinstance(subary, TaggableCLArray):
                 key_to_frozen_subary[key] = subary.with_queue(None)
-            elif isinstance(subary, cla.Array):
+            elif isinstance(subary, self._frozen_array_types):
                 from warnings import warn
-                warn("Freezing pyopencl.array.Array will be deprecated in 2023."
-                     " Use `to_tagged_cl_array` to convert the array to"
-                     " TaggableCLArray", DeprecationWarning, stacklevel=2)
-                key_to_frozen_subary[key] = to_tagged_cl_array(
-                    subary.with_queue(None),
-                    axes=None,
-                    tags=frozenset())
+                warn(f"Invoking {type(self).__name__}.freeze with"
+                    f" {type(subary).__name__} will be unsupported in 2023. Use"
+                    " `to_tagged_cl_array` to convert instances to TaggableCLArray.",
+                    DeprecationWarning, stacklevel=2)
+
+                key_to_frozen_subary[key] = (
+                    to_tagged_cl_array(subary.with_queue(None)))
             elif isinstance(subary, pt.DataWrapper):
                 # trivial freeze.
                 key_to_frozen_subary[key] = to_tagged_cl_array(
                     subary.data,
                     axes=get_cl_axes_from_pt_axes(subary.axes),
                     tags=subary.tags)
-            else:
-                if not isinstance(subary, pt.Array):
-                    raise TypeError(f"{type(self).__name__}.freeze invoked "
-                                    f"with non-pytato array of type '{type(array)}'")
-
+            elif isinstance(subary, pt.Array):
                 # Don't be tempted to take shortcuts here, e.g. for empty
                 # arrays, as this will inhibit metadata propagation that
                 # may happen in transform_dag below. See
                 # https://github.com/inducer/arraycontext/pull/167#issuecomment-1151877480
                 key_to_pt_arrays[key] = subary
+            else:
+                raise TypeError(
+                    f"{type(self).__name__}.freeze invoked with an unsupported "
+                    f"array type: got '{type(subary).__name__}', but expected one "
+                    f"of {self.array_types}")
 
         # }}}
 
         pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
-            key_to_pt_arrays)
-
+                key_to_pt_arrays)
         normalized_expr, bound_arguments = _normalize_pt_expr(
                 pt_dict_of_named_arrays)
 
@@ -352,8 +390,8 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
             pt_prg = self._freeze_prg_cache[normalized_expr]
         except KeyError:
             try:
-                transformed_dag, function_name = \
-                        self._dag_transform_cache[normalized_expr]
+                transformed_dag, function_name = (
+                        self._dag_transform_cache[normalized_expr])
             except KeyError:
                 transformed_dag = self.transform_dag(normalized_expr)
 
@@ -373,16 +411,16 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
                 self._dag_transform_cache[normalized_expr] = (
                         transformed_dag, function_name)
 
+            from arraycontext.loopy import _DEFAULT_LOOPY_OPTIONS
             pt_prg = pt.generate_loopy(transformed_dag,
-                                       options=lp.Options(return_dict=True,
-                                                          no_numpy=True),
+                                       options=_DEFAULT_LOOPY_OPTIONS,
                                        cl_device=self.queue.device,
                                        function_name=function_name)
             pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
             self._freeze_prg_cache[normalized_expr] = pt_prg
         else:
-            transformed_dag, function_name = \
-                    self._dag_transform_cache[normalized_expr]
+            transformed_dag, function_name = (
+                    self._dag_transform_cache[normalized_expr])
 
         assert len(pt_prg.bound_arguments) == 0
         evt, out_dict = pt_prg(self.queue, **bound_arguments)
@@ -391,47 +429,79 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
 
         key_to_frozen_subary = {
             **key_to_frozen_subary,
-            **{k: to_tagged_cl_array(v.with_queue(None),
-                                     get_cl_axes_from_pt_axes(transformed_dag[k]
-                                                              .expr
-                                                              .axes),
-                                     transformed_dag[k].expr.tags)
+            **{k: to_tagged_cl_array(
+                    v.with_queue(None),
+                    axes=get_cl_axes_from_pt_axes(transformed_dag[k].expr.axes),
+                    tags=transformed_dag[k].expr.tags)
                for k, v in out_dict.items()}
         }
 
-        def _to_frozen(key: Tuple[Any, ...], ary: ArrayT):
+        def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray:
             key_str = "_ary" + _ary_container_key_stringifier(key)
             return key_to_frozen_subary[key_str]
 
-        return with_array_context(rec_keyed_map_array_container(_to_frozen,
-                                                                array),
-                                  actx=None)
+        return with_array_context(
+                rec_keyed_map_array_container(_to_frozen, array),
+                actx=None)
 
     def thaw(self, array):
         import pytato as pt
         from .utils import get_pt_axes_from_cl_axes
-        from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray,
-                                                                  to_tagged_cl_array)
-        import pyopencl.array as cl_array
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
 
-        def _rec_thaw(ary):
-            if isinstance(ary, TaggableCLArray):
-                pass
-            elif isinstance(ary, cl_array.Array):
-                ary = to_tagged_cl_array(ary, axes=None, tags=frozenset())
-            else:
-                raise TypeError(f"{type(self).__name__}.thaw expects "
-                                "'TaggableCLArray' or 'cl.array.Array' got "
-                                f"{type(ary)}.")
+        def _thaw(ary):
             return pt.make_data_wrapper(ary.with_queue(self.queue),
                                         axes=get_pt_axes_from_cl_axes(ary.axes),
                                         tags=ary.tags)
 
-        return with_array_context(rec_map_array_container(_rec_thaw, array),
-                                  actx=self)
+        return with_array_context(
+            self._rec_map_container(_thaw, array, (tga.TaggableCLArray,)),
+            actx=self)
+
+    def tag(self, tags: ToTagSetConvertible, array):
+        def _tag(ary):
+            return ary.tagged(_preprocess_array_tags(tags))
+
+        return self._rec_map_container(_tag, array)
+
+    def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
+        def _tag_axis(ary):
+            return ary.with_tagged_axis(iaxis, tags)
+
+        return self._rec_map_container(_tag_axis, array)
 
     # }}}
 
+    # {{{ compilation
+
+    def call_loopy(self, program, **kwargs):
+        import pytato as pt
+        from pytato.scalar_expr import SCALAR_CLASSES
+        from pytato.loopy import call_loopy
+        from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
+
+        entrypoint = program.default_entrypoint.name
+
+        # {{{ preprocess args
+
+        processed_kwargs = {}
+
+        for kw, arg in sorted(kwargs.items()):
+            if isinstance(arg, (pt.Array,) + SCALAR_CLASSES):
+                pass
+            elif isinstance(arg, TaggableCLArray):
+                arg = self.thaw(arg)
+            else:
+                raise ValueError(f"call_loopy argument '{kw}' expected to be an"
+                                 " instance of 'pytato.Array', 'Number' or"
+                                 f"'TaggableCLArray', got '{type(arg)}'")
+
+            processed_kwargs[kw] = arg
+
+        # }}}
+
+        return call_loopy(program, processed_kwargs, entrypoint)
+
     def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
         from .compile import LazilyPyOpenCLCompilingFunctionCaller
         return LazilyPyOpenCLCompilingFunctionCaller(self, f)
@@ -442,39 +512,30 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
         dag = pt.transform.materialize_with_mpms(dag)
         return dag
 
-    def tag(self, tags: ToTagSetConvertible, array):
-        return rec_map_array_container(
-                lambda x: x.tagged(_preprocess_array_tags(tags)),
-                array)
-
-    def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
-        return rec_map_array_container(
-            lambda x: x.with_tagged_axis(iaxis, tags),
-            array)
-
     def einsum(self, spec, *args, arg_names=None, tagged=()):
-        import pyopencl.array as cla
         import pytato as pt
-        from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray,
-                                                                  to_tagged_cl_array)
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
+
         if arg_names is None:
             arg_names = (None,) * len(args)
 
         def preprocess_arg(name, arg):
-            if isinstance(arg, TaggableCLArray):
+            if isinstance(arg, tga.TaggableCLArray):
                 ary = self.thaw(arg)
-            elif isinstance(arg, cla.Array):
+            elif isinstance(arg, self._frozen_array_types):
                 from warnings import warn
-                warn("Passing pyopencl.array.Array to einsum will be "
-                     "deprecated in 2023."
-                     " Use `to_tagged_cl_array` to convert the array to"
-                     " TaggableCLArray.", DeprecationWarning, stacklevel=2)
-                ary = self.thaw(to_tagged_cl_array(arg,
-                                                   axes=None,
-                                                   tags=frozenset()))
-            else:
-                assert isinstance(arg, pt.Array)
+                warn(f"Invoking {type(self).__name__}.einsum with"
+                    f" {type(arg).__name__} will be unsupported in 2023. Use"
+                    " `to_tagged_cl_array` to convert instances to TaggableCLArray.",
+                    DeprecationWarning, stacklevel=2)
+                ary = self.thaw(tga.to_tagged_cl_array(arg))
+            elif isinstance(arg, pt.Array):
                 ary = arg
+            else:
+                raise TypeError(
+                    f"{type(self).__name__}.einsum invoked with an unsupported "
+                    f"array type: got '{type(arg).__name__}', but expected one "
+                    f"of {self.array_types}")
 
             if name is not None:
                 # Tagging Placeholders with naming-related tags is pointless:
@@ -493,6 +554,11 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
             for name, arg in zip(arg_names, args)
             ]).tagged(_preprocess_array_tags(tagged))
 
+    def clone(self):
+        return type(self)(self.queue, self.allocator)
+
+    # }}}
+
 # }}}
 
 
@@ -521,34 +587,71 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
         super().__init__(compile_trace_callback=compile_trace_callback)
         self.array_types = (pt.Array, DeviceArray)
 
-    def clone(self):
-        return type(self)()
+    @property
+    def _frozen_array_types(self) -> Tuple[Type, ...]:
+        from jax.numpy import DeviceArray
+        return (DeviceArray, )
+
+    def _rec_map_container(
+            self, func: Callable[[Array], Array], array: ArrayOrContainer,
+            allowed_types: Optional[Tuple[type, ...]] = None, *,
+            default_scalar: Optional[ScalarLike] = None,
+            strict: bool = False) -> ArrayOrContainer:
+        if allowed_types is None:
+            allowed_types = self.array_types
+
+        def _wrapper(ary):
+            if isinstance(ary, allowed_types):
+                return func(ary)
+            elif np.isscalar(ary):
+                if default_scalar is None:
+                    return ary
+                else:
+                    return np.array(ary).dtype.type(default_scalar)
+            else:
+                raise TypeError(
+                    f"{type(self).__name__}.{func.__name__[1:]} invoked with "
+                    f"an unsupported array type: got '{type(ary).__name__}', "
+                    f"but expected one of {allowed_types}")
+
+        return rec_map_array_container(_wrapper, array)
+
+    # {{{ ArrayContext interface
+
+    def zeros_like(self, ary):
+        def _zeros_like(array):
+            return self.zeros(array.shape, array.dtype)
 
-    def from_numpy(self, array: Union[np.ndarray, ScalarLike]):
+        return self._rec_map_container(_zeros_like, ary, default_scalar=0)
+
+    def from_numpy(self, array):
         import jax
         import pytato as pt
-        return pt.make_data_wrapper(jax.device_put(array))
 
-    def to_numpy(self, array):
-        if np.isscalar(array):
-            return array
+        def _from_numpy(ary):
+            return pt.make_data_wrapper(jax.device_put(ary))
+
+        return with_array_context(
+            self._rec_map_container(_from_numpy, array, (np.ndarray,)),
+            actx=self)
 
+    def to_numpy(self, array):
         import jax
-        return jax.device_get(self.freeze(array))
 
-    @property
-    def frozen_array_types(self) -> Tuple[Type, ...]:
-        from jax.numpy import DeviceArray
-        return (DeviceArray, )
+        def _to_numpy(ary):
+            return jax.device_get(ary)
 
-    def call_loopy(self, program, **kwargs):
-        raise ValueError(f"{type(self)} does not support calling loopy.")
+        return with_array_context(
+            self._rec_map_container(_to_numpy, self.freeze(array)),
+            actx=None)
 
     def freeze(self, array):
+        if np.isscalar(array):
+            return array
+
         import pytato as pt
 
         from jax.numpy import DeviceArray
-        from arraycontext.context import ArrayT
         from arraycontext.container.traversal import rec_keyed_map_array_container
         from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
 
@@ -557,10 +660,9 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
         key_to_pt_arrays: Dict[str, pt.Array] = {}
 
         def _record_leaf_ary_in_dict(key: Tuple[Any, ...],
-                                     ary: Union[DeviceArray, pt.Array]):
+                                     ary: Union[DeviceArray, pt.Array]) -> None:
             key_str = "_ary" + _ary_container_key_stringifier(key)
             array_as_dict[key_str] = ary
-            return ary
 
         rec_keyed_map_array_container(_record_leaf_ary_in_dict, array)
 
@@ -572,12 +674,13 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
             elif isinstance(subary, pt.DataWrapper):
                 # trivial freeze.
                 key_to_frozen_subary[key] = subary.data.block_until_ready()
-            else:
-                if not isinstance(subary, pt.Array):
-                    raise TypeError(f"{type(self).__name__}.freeze invoked "
-                                    f"with non-pytato array of type '{type(array)}'")
-
+            elif isinstance(subary, pt.Array):
                 key_to_pt_arrays[key] = subary
+            else:
+                raise TypeError(
+                    f"{type(self).__name__}.freeze invoked with an unsupported "
+                    f"array type: got '{type(subary).__name__}', but expected one "
+                    f"of {self.array_types}")
 
         # }}}
 
@@ -593,59 +696,59 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
                for k, v in out_dict.items()}
         }
 
-        def _to_frozen(key: Tuple[Any, ...], ary: ArrayT):
+        def _to_frozen(key: Tuple[Any, ...], ary) -> DeviceArray:
             key_str = "_ary" + _ary_container_key_stringifier(key)
             return key_to_frozen_subary[key_str]
 
-        return with_array_context(rec_keyed_map_array_container(_to_frozen,
-                                                                array),
-                                  actx=None)
+        return with_array_context(
+            rec_keyed_map_array_container(_to_frozen, array),
+            actx=None)
 
     def thaw(self, array):
         import pytato as pt
-        from jax.numpy import DeviceArray
 
-        def _rec_thaw(ary):
-            if isinstance(ary, DeviceArray):
-                pass
-            else:
-                raise TypeError(f"{type(self).__name__}.thaw expects "
-                                f"'jax.DeviceArray' got {type(ary)}.")
+        def _thaw(ary):
             return pt.make_data_wrapper(ary)
 
-        return with_array_context(rec_map_array_container(_rec_thaw, array),
-                                  actx=self)
+        return with_array_context(
+            self._rec_map_container(_thaw, array, self._frozen_array_types),
+            actx=self)
 
     def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
         from .compile import LazilyJAXCompilingFunctionCaller
         return LazilyJAXCompilingFunctionCaller(self, f)
 
     def tag(self, tags: ToTagSetConvertible, array):
-        import pytato as pt
         from jax.numpy import DeviceArray
 
-        def _rec_tag(ary):
+        def _tag(ary):
             if isinstance(ary, DeviceArray):
                 return ary
             else:
-                assert isinstance(ary, pt.Array)
                 return ary.tagged(_preprocess_array_tags(tags))
 
-        return rec_map_array_container(_rec_tag, array)
+        return self._rec_map_container(_tag, array)
 
     def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
-        import pytato as pt
         from jax.numpy import DeviceArray
 
-        def _rec_tag_axis(ary):
+        def _tag_axis(ary):
             if isinstance(ary, DeviceArray):
                 return ary
             else:
-                assert isinstance(ary, pt.Array)
                 return ary.with_tagged_axis(iaxis, tags)
 
-        return rec_map_array_container(_rec_tag_axis,
-                                       array)
+        return self._rec_map_container(_tag_axis, array)
+
+    # }}}
+
+    # {{{ compilation
+
+    def call_loopy(self, program, **kwargs):
+        raise NotImplementedError(
+            "Calling loopy on JAX arrays is not supported. Maybe rewrite"
+            " the loopy kernel as numpy-flavored array operations using"
+            " ArrayContext.np.")
 
     def einsum(self, spec, *args, arg_names=None, tagged=()):
         import pytato as pt
@@ -656,9 +759,13 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
         def preprocess_arg(name, arg):
             if isinstance(arg, DeviceArray):
                 ary = self.thaw(arg)
-            else:
-                assert isinstance(arg, pt.Array)
+            elif isinstance(arg, pt.Array):
                 ary = arg
+            else:
+                raise TypeError(
+                    f"{type(self).__name__}.einsum invoked with an unsupported "
+                    f"array type: got '{type(arg).__name__}', but expected one "
+                    f"of {self.array_types}")
 
             if name is not None:
                 # Tagging Placeholders with naming-related tags is pointless:
@@ -677,7 +784,9 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
             for name, arg in zip(arg_names, args)
             ]).tagged(_preprocess_array_tags(tagged))
 
-# }}}
+    def clone(self):
+        return type(self)()
 
+# }}}
 
 # vim: foldmethod=marker
diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index 9e92adf..53170ca 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -523,7 +523,7 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
         elif isinstance(arg, pt.array.DataWrapper):
             # got a Datawrapper => simply gets its data
             arg = arg.data
-        elif isinstance(arg, actx.frozen_array_types):
+        elif isinstance(arg, actx._frozen_array_types):
             # got a frozen array  => do nothing
             pass
         elif isinstance(arg, pt.Array):
-- 
GitLab