# Skip to content
# Snippets Groups Projects
# __init__.py 32.43 KiB
from __future__ import annotations


__doc__ = """
.. currentmodule:: arraycontext

A :mod:`pytato`-based array context defers the evaluation of an array until it is
frozen. The execution contexts for the evaluations are specific to an
:class:`~arraycontext.ArrayContext` type. For example,
:class:`~arraycontext.PytatoPyOpenCLArrayContext` uses :mod:`pyopencl` to
JIT-compile and execute the array expressions.

The following :mod:`pytato`-based array contexts are provided:

.. autoclass:: PytatoPyOpenCLArrayContext
.. autoclass:: PytatoJAXArrayContext


Compiling a Python callable (Internal)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. automodule:: arraycontext.impl.pytato.compile


Utils
^^^^^

.. automodule:: arraycontext.impl.pytato.utils
"""
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import abc
import sys
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

import numpy as np

from pytools import memoize_method
from pytools.tag import Tag, ToTagSetConvertible, normalize_tags

from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import (
    Array,
    ArrayContext,
    ArrayOrContainer,
    ScalarLike,
    UntransformedCodeWarning,
)
from arraycontext.metadata import NameHint


if TYPE_CHECKING:
    import loopy as lp
    import pyopencl as cl
    import pytato

if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
    import pyopencl as cl

import logging


logger = logging.getLogger(__name__)


# {{{ tag conversion

def _preprocess_array_tags(tags: ToTagSetConvertible) -> frozenset[Tag]:
    """
    Normalize *tags* and replace an :class:`arraycontext.metadata.NameHint`
    by a :class:`pytato.tags.PrefixNamed` carrying the same name.

    If no :class:`~arraycontext.metadata.NameHint` is present, the normalized
    tags are returned as-is (and :mod:`pytato` is not imported). A warning is
    issued if a ``PrefixNamed`` tag is already present alongside the hint.

    :raises ValueError: if more than one ``NameHint`` (or, when a hint is
        present, more than one ``PrefixNamed``) is found, as a consequence of
        the single-element unpacking below.
    """
    normalized = normalize_tags(tags)

    hints = [tag for tag in normalized if isinstance(tag, NameHint)]
    if not hints:
        return normalized

    # single-element unpacking: fails unless there is exactly one hint
    [hint] = hints

    from pytato.tags import PrefixNamed
    existing_prefixes = [
        tag for tag in normalized if isinstance(tag, PrefixNamed)]

    if existing_prefixes:
        [existing] = existing_prefixes
        from warnings import warn
        warn("When converting a "
                f"arraycontext.metadata.NameHint('{hint.name}') "
                "to pytato.tags.PrefixNamed, "
                f"PrefixNamed('{existing.prefix}') "
                "was already present.", stacklevel=1)

    converted = frozenset({PrefixNamed(hint.name)})
    return (normalized | converted) - {hint}

# }}}


class _NotOnlyDataWrappers(Exception):  # noqa: N818
    """
    Internal control-flow signal: raised while traversing a container in
    ``freeze_thaw`` when a leaf is neither a data wrapper nor an already
    frozen array, so that the generic implementation is used instead.
    """


# {{{ _BasePytatoArrayContext

class _BasePytatoArrayContext(ArrayContext, abc.ABC):
    """
    An abstract :class:`ArrayContext` that uses :mod:`pytato` data types to
    represent arrays.

    .. automethod:: __init__

    .. automethod:: transform_dag

    .. automethod:: compile
    """

    def __init__(
            self, *,
            compile_trace_callback: Callable[[Any, str, Any], None] | None = None
            ) -> None:
        """
        :arg compile_trace_callback: A function of three arguments
            *(what, stage, ir)*, where *what* identifies the object
            being compiled, *stage* is a string describing the compilation
            pass, and *ir* is an object containing the intermediate
            representation. This interface should be considered
            unstable.
        """
        super().__init__()

        import pytato as pt
        # Per-context caches, both keyed by a DictOfNamedArrays expression.
        # They start out empty; subclasses populate them while freezing.
        self._freeze_prg_cache: dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {}
        self._dag_transform_cache: dict[
                pt.DictOfNamedArrays,
                tuple[pt.DictOfNamedArrays, str]] = {}

        if compile_trace_callback is None:
            # Substitute a do-nothing callback so that callers never have
            # to check for None before invoking it.
            def _compile_trace_callback(what, stage, ir):
                pass

            compile_trace_callback = _compile_trace_callback

        self._compile_trace_callback = compile_trace_callback

    def _get_fake_numpy_namespace(self):
        """Return the :mod:`pytato`-backed fake :mod:`numpy` namespace."""
        from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace
        return PytatoFakeNumpyNamespace(self)

    # abc.abstractproperty has been deprecated since Python 3.3; stacking
    # property on top of abstractmethod is the documented replacement.
    @property
    @abc.abstractmethod
    def _frozen_array_types(self) -> tuple[type, ...]:
        """
        Returns valid frozen array types for the array context.
        """

    # {{{ compilation

    def transform_dag(self, dag: pytato.DictOfNamedArrays
                      ) -> pytato.DictOfNamedArrays:
        """
        Returns a transformed version of *dag*. Sub-classes are supposed to
        override this method to implement context-specific transformations on
        *dag* (most likely to perform domain-specific optimizations). Every
        :mod:`pytato` DAG that is compiled to a GPU-kernel is
        passed through this routine.

        :arg dag: An instance of :class:`pytato.DictOfNamedArrays`
        :returns: A transformed version of *dag*.
        """
        return dag

    def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
        """
        Return *t_unit* unchanged after emitting an
        :class:`~arraycontext.context.UntransformedCodeWarning`.

        Subclasses are expected to override this with actual transformations.
        """
        from warnings import warn
        # NOTE: adjacent string literals are concatenated verbatim, so every
        # fragment must carry its own trailing space.
        warn("Using the base "
                f"{type(self).__name__}.transform_loopy_program "
                "to transform a translation unit. "
                "This is a no-op and will result in unoptimized C code for "
                "the requested optimization, all in a single statement. "
                "This will work, but is unlikely to be performant. "
                f"Instead, subclass {type(self).__name__} and implement "
                "the specific transform logic required to transform the program "
                "for your package or application. Check higher-level packages "
                "(e.g. meshmode), which may already have subclasses you may want "
                "to build on.",
                UntransformedCodeWarning, stacklevel=2)

        return t_unit

    @abc.abstractmethod
    def einsum(self, spec, *args, arg_names=None, tagged=()):
        """Abstract: subclasses provide the einsum implementation."""

    # }}}

    # {{{ properties

    @property
    def permits_inplace_modification(self):
        """Always *False* for this class."""
        return False

    @property
    def supports_nonscalar_broadcasting(self):
        """Always *True* for this class."""
        return True

    @property
    def permits_advanced_indexing(self):
        """Always *True* for this class."""
        return True

    def get_target(self):
        """
        Return the code generation target. The base implementation returns
        *None*; subclasses may override this.
        """
        return None

    # }}}

# }}}


# {{{ PytatoPyOpenCLArrayContext

class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
    """
    An :class:`ArrayContext` that uses :mod:`pytato` data types to represent
    the arrays targeting OpenCL for offloading operations.

    .. attribute:: queue

        A :class:`pyopencl.CommandQueue`.

    .. attribute:: allocator

        A :mod:`pyopencl` memory allocator. Can also be None (default) or False
        to use the default allocator.

    .. automethod:: __init__

    .. automethod:: transform_dag

    .. automethod:: compile
    """
    def __init__(
            self, queue: cl.CommandQueue, allocator=None, *,
            use_memory_pool: bool | None = None,
            compile_trace_callback: Callable[[Any, str, Any], None] | None = None,

            # do not use: only for testing
            _force_svm_arg_limit: int | None = None,
            ) -> None:
        """
        :arg queue: the command queue stored as :attr:`queue`. When no
            *allocator* is given, its device is checked for coarse-grain
            buffer SVM support in order to pick an allocator.
        :arg allocator: if *None*, an SVM allocator is created when the
            device has coarse-grain buffer SVM, and an immediate allocator
            otherwise. If given, a 4-byte trial allocation is performed to
            find out whether the allocator hands out SVM memory.
        :arg use_memory_pool: if true (and no *allocator* is given), wrap
            the automatically chosen allocator in the matching memory pool.
        :arg compile_trace_callback: A function of three arguments
            *(what, stage, ir)*, where *what* identifies the object
            being compiled, *stage* is a string describing the compilation
            pass, and *ir* is an object containing the intermediate
            representation. This interface should be considered
            unstable.

        :raises TypeError: if both *allocator* and *use_memory_pool* are
            specified.
        """
        if not (allocator is None or use_memory_pool is None):
            raise TypeError("may not specify both allocator and use_memory_pool")

        self.using_svm = None

        if allocator is not None:
            # NOTE(review): the class docstring says *False* is accepted for
            # the allocator, but any non-None value, including False, is
            # called here -- confirm the intended behavior.

            # Check whether the passed allocator allocates SVM
            try:
                from pyopencl import SVMPointer
                probe = allocator(4)
                self.using_svm = isinstance(probe, SVMPointer)
            except ImportError:
                self.using_svm = False
        else:
            from pyopencl.characterize import has_coarse_grain_buffer_svm
            svm_available = has_coarse_grain_buffer_svm(queue.device)
            self.using_svm = bool(svm_available)

            if svm_available:
                from pyopencl.tools import SVMAllocator
                allocator = SVMAllocator(queue.context, queue=queue)

                if use_memory_pool:
                    from pyopencl.tools import SVMPool
                    allocator = SVMPool(allocator)
            else:
                from pyopencl.tools import ImmediateAllocator
                allocator = ImmediateAllocator(queue)

                if use_memory_pool:
                    from pyopencl.tools import MemoryPool
                    allocator = MemoryPool(allocator)

        import pyopencl.array as cla
        import pytato as pt
        super().__init__(compile_trace_callback=compile_trace_callback)

        self.queue = queue
        self.allocator = allocator
        self.array_types = (pt.Array, cla.Array)

        # Holding on to the context keeps it alive; it is also what
        # generated programs get bound to in freeze().
        self.context = self.queue.context

        self._force_svm_arg_limit = _force_svm_arg_limit

    @property
    def _frozen_array_types(self) -> tuple[type, ...]:
        """The frozen array types of this context: :class:`pyopencl.array.Array`."""
        from pyopencl.array import Array as CLArray
        return (CLArray,)

    def _rec_map_container(
            self, func: Callable[[Array], Array], array: ArrayOrContainer,
            allowed_types: tuple[type, ...] | None = None, *,
            default_scalar: ScalarLike | None = None,
            strict: bool = False) -> ArrayOrContainer:
        """
        Apply *func* to every leaf of *array* whose type is in
        *allowed_types*.

        :arg allowed_types: leaf types handed to *func*. Defaults to
            :class:`pytato.Array` and ``TaggableCLArray``.
        :arg default_scalar: if not *None*, scalar leaves are replaced by this
            value, converted to the scalar's own dtype; otherwise scalar
            leaves are returned unchanged.
        :arg strict: accepted as part of the signature, but not consulted by
            this implementation.

        :raises TypeError: for a leaf that is neither of an allowed type nor
            a scalar.
        """
        import pytato as pt

        from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray

        accepted = (
            (pt.Array, TaggableCLArray) if allowed_types is None
            else allowed_types)

        def _apply_to_leaf(leaf):
            if isinstance(leaf, accepted):
                return func(leaf)

            if not np.isscalar(leaf):
                raise TypeError(
                    f"{func.__qualname__} invoked with "
                    f"an unsupported array type: got '{type(leaf).__name__}', "
                    f"but expected one of {accepted}")

            if default_scalar is None:
                return leaf

            return np.array(leaf).dtype.type(default_scalar)

        return rec_map_array_container(_apply_to_leaf, array)

    # {{{ ArrayContext interface

    def from_numpy(self, array):
        """
        Transfer the :class:`numpy.ndarray` leaves of *array* to the device
        using this context's queue and allocator, wrap each one in a
        :mod:`pytato` data wrapper and attach this array context to the
        result.
        """
        import pytato as pt

        from arraycontext.impl.pyopencl.taggable_cl_array import to_device

        def _from_numpy(ary):
            on_device = to_device(self.queue, ary, allocator=self.allocator)
            return pt.make_data_wrapper(on_device)

        wrapped = self._rec_map_container(
            _from_numpy, array, (np.ndarray,), strict=True)

        return with_array_context(wrapped, actx=self)

    def to_numpy(self, array):
        """
        Freeze *array*, then fetch every array leaf via its ``get`` method
        on this context's queue. The result carries no array context.
        """
        def _to_numpy(ary):
            return ary.get(queue=self.queue)

        frozen = self.freeze(array)
        on_host = self._rec_map_container(_to_numpy, frozen)

        return with_array_context(on_host, actx=None)

    @memoize_method
    def get_target(self):
        """
        Return the code generation target for this context (memoized).

        An argument-size-limiting target is returned if a limit was forced
        through ``_force_svm_arg_limit``, or if SVM is in use on a GPU device
        with coarse-grain buffer SVM. In all other cases, the target of the
        base class is returned.
        """
        import pyopencl as cl
        import pyopencl.characterize as cl_char

        dev = self.queue.device
        forced_limit = self._force_svm_arg_limit

        needs_limit = forced_limit is not None or (
            self.using_svm and dev.type & cl.device_type.GPU
            and cl_char.has_coarse_grain_buffer_svm(dev))

        if not needs_limit:
            return super().get_target()

        if dev.max_parameter_size == 4352:
            # Nvidia devices and PTXAS declare a limit of 4352 bytes,
            # which is incorrect. The CUDA documentation at
            # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#function-parameters
            # mentions a limit of 4KB, which is also incorrect.
            # The actual limit appears to be around 4080 bytes, at least
            # on a K40; the value used here is reduced further to be on
            # the safe side.

            # The naming convention is not very consistent across Nvidia
            # GPUs, so only the maximum parameter size is used to decide
            # whether this is an Nvidia GPU.

            limit = 4096-200

            from warnings import warn
            warn("Running on an Nvidia GPU, reducing the argument "
                f"size limit from 4352 to {limit}.", stacklevel=1)
        else:
            limit = dev.max_parameter_size

        # a forced limit (testing only) overrides whatever was determined above
        if forced_limit is not None:
            limit = forced_limit

        logger.info(
                "limiting argument buffer size for %s to %d bytes",
                dev, limit)

        from arraycontext.impl.pytato.utils import (
            ArgSizeLimitingPytatoLoopyPyOpenCLTarget,
        )
        return ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit)

    def freeze(self, array):
        """
        Evaluate the :mod:`pytato` expressions among the leaves of *array*
        and return a container whose array leaves are ``TaggableCLArray``
        instances detached from any queue, with no array context attached.

        Scalars are returned unchanged. Leaves that are already device
        arrays or :class:`pytato.DataWrapper` instances are converted
        without generating any code. All remaining :class:`pytato.Array`
        leaves are compiled into a single program, which is cached per
        normalized expression in ``self._freeze_prg_cache``.

        :raises TypeError: if a leaf is of an unsupported type.
        """
        if np.isscalar(array):
            return array

        import pyopencl.array as cla
        import pytato as pt

        from arraycontext.container.traversal import rec_keyed_map_array_container
        from arraycontext.impl.pyopencl.taggable_cl_array import (
            TaggableCLArray,
            to_tagged_cl_array,
        )
        from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
        from arraycontext.impl.pytato.utils import (
            _normalize_pt_expr,
            get_cl_axes_from_pt_axes,
        )

        # All three dicts are keyed by "_ary" + the stringified container key
        # of a leaf:
        #   array_as_dict:        every leaf, as found in *array*
        #   key_to_frozen_subary: leaves for which a frozen array is available
        #   key_to_pt_arrays:     leaves that still need to be evaluated
        array_as_dict: dict[str, cla.Array | TaggableCLArray | pt.Array] = {}
        key_to_frozen_subary: dict[str, TaggableCLArray] = {}
        key_to_pt_arrays: dict[str, pt.Array] = {}

        def _record_leaf_ary_in_dict(
                key: tuple[Any, ...],
                ary: cla.Array | TaggableCLArray | pt.Array) -> None:
            key_str = "_ary" + _ary_container_key_stringifier(key)
            array_as_dict[key_str] = ary

        # The traversal is only used for its side effect of filling
        # array_as_dict; its return value is discarded.
        rec_keyed_map_array_container(_record_leaf_ary_in_dict, array)

        # {{{ remove any non pytato arrays from array_as_dict

        for key, subary in array_as_dict.items():
            if isinstance(subary, TaggableCLArray):
                key_to_frozen_subary[key] = subary.with_queue(None)
            elif isinstance(subary, self._frozen_array_types):
                from warnings import warn
                warn(f"Invoking {type(self).__name__}.freeze with"
                    f" {type(subary).__name__} will be unsupported in 2023. Use"
                    " `to_tagged_cl_array` to convert instances to TaggableCLArray.",
                    DeprecationWarning, stacklevel=2)

                key_to_frozen_subary[key] = (
                    to_tagged_cl_array(subary.with_queue(None)))
            elif isinstance(subary, pt.DataWrapper):
                # trivial freeze.
                key_to_frozen_subary[key] = to_tagged_cl_array(
                    subary.data,
                    axes=get_cl_axes_from_pt_axes(subary.axes),
                    tags=subary.tags)
            elif isinstance(subary, pt.Array):
                # Don't be tempted to take shortcuts here, e.g. for empty
                # arrays, as this will inhibit metadata propagation that
                # may happen in transform_dag below. See
                # https://github.com/inducer/arraycontext/pull/167#issuecomment-1151877480
                key_to_pt_arrays[key] = subary
            else:
                raise TypeError(
                    f"{type(self).__name__}.freeze invoked with an unsupported "
                    f"array type: got '{type(subary).__name__}', but expected one "
                    f"of {self.array_types}")

        # }}}

        # Maps a leaf's container key back to its frozen array. Note that
        # key_to_frozen_subary is rebound further down; this closure picks
        # up the rebound dict when it is eventually called.
        def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
            key_str = "_ary" + _ary_container_key_stringifier(key)
            return key_to_frozen_subary[key_str]

        if not key_to_pt_arrays:
            # all cl arrays => no need to perform any codegen
            return with_array_context(
                    rec_keyed_map_array_container(_to_frozen, array),
                    actx=None)

        pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
                key_to_pt_arrays)
        normalized_expr, bound_arguments = _normalize_pt_expr(
                pt_dict_of_named_arrays)

        try:
            pt_prg = self._freeze_prg_cache[normalized_expr]
        except KeyError:
            # No compiled program yet: obtain the transformed DAG and the
            # function name (from the cache if possible), then generate code.
            try:
                transformed_dag, function_name = (
                        self._dag_transform_cache[normalized_expr])
            except KeyError:
                transformed_dag = self.transform_dag(normalized_expr)

                from pytato.tags import PrefixNamed
                name_hint_tags = []
                for subary in key_to_pt_arrays.values():
                    name_hint_tags.extend(subary.tags_of_type(PrefixNamed))

                from pytools import common_prefix
                name_hint = common_prefix([nh.prefix for nh in name_hint_tags])

                # Use the common prefix of the name hints, if it is non-empty,
                # to name the generated function.
                function_name = f"frozen_{name_hint}" if name_hint else "frozen_result"

                self._dag_transform_cache[normalized_expr] = (
                        transformed_dag, function_name)

            from arraycontext.loopy import _DEFAULT_LOOPY_OPTIONS
            opts = _DEFAULT_LOOPY_OPTIONS
            assert opts.return_dict

            pt_prg = pt.generate_loopy(transformed_dag,
                                       options=opts,
                                       cl_device=self.queue.device,
                                       function_name=function_name,
                                       target=self.get_target()
                                       ).bind_to_context(self.context)
            pt_prg = pt_prg.with_transformed_translation_unit(
                    self.transform_loopy_program)
            self._freeze_prg_cache[normalized_expr] = pt_prg
        else:
            # Program cache hit: transformed_dag is still needed below for
            # the axes and tags of the outputs. This lookup relies on both
            # caches having been filled together in the branch above.
            transformed_dag, function_name = (
                    self._dag_transform_cache[normalized_expr])

        assert len(pt_prg.bound_arguments) == 0
        evt, out_dict = pt_prg(self.queue,
                allocator=self.allocator,
                **bound_arguments)
        evt.wait()
        # computed outputs must not collide with the trivially frozen leaves
        assert len(set(out_dict) & set(key_to_frozen_subary)) == 0

        # Merge the computed outputs with the trivially frozen leaves,
        # carrying over axes and tags from the transformed DAG.
        key_to_frozen_subary = {
            **key_to_frozen_subary,
            **{k: to_tagged_cl_array(
                    v.with_queue(None),
                    axes=get_cl_axes_from_pt_axes(transformed_dag[k].expr.axes),
                    tags=transformed_dag[k].expr.tags)
               for k, v in out_dict.items()}
        }

        return with_array_context(
                rec_keyed_map_array_container(_to_frozen, array),
                actx=None)

    def thaw(self, array):
        """
        Attach this context's queue to every ``TaggableCLArray`` leaf of
        *array* and wrap it in a :mod:`pytato` data wrapper, carrying over
        the leaf's axes and tags. The result has this array context attached.
        """
        import pytato as pt

        from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
        from .utils import get_pt_axes_from_cl_axes

        def _thaw(ary):
            with_queue = ary.with_queue(self.queue)
            pt_axes = get_pt_axes_from_cl_axes(ary.axes)
            return pt.make_data_wrapper(with_queue, axes=pt_axes, tags=ary.tags)

        thawed = self._rec_map_container(_thaw, array, (TaggableCLArray,))

        return with_array_context(thawed, actx=self)

    def freeze_thaw(self, array):
        """
        Shortcut for freezing and then thawing *array*.

        If every array leaf is a :class:`pytato.DataWrapper` or a
        ``TaggableCLArray``, the leaves are returned unchanged (with this
        array context attached). Otherwise, this defers to the
        implementation of the base class.
        """
        import pytato as pt

        from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray

        passthrough_types = (pt.DataWrapper, TaggableCLArray)

        def _ft(ary):
            if not isinstance(ary, passthrough_types):
                raise _NotOnlyDataWrappers()

            return ary

        try:
            unchanged = self._rec_map_container(_ft, array)
            return with_array_context(unchanged, actx=self)
        except _NotOnlyDataWrappers:
            return super().freeze_thaw(array)

    def tag(self, tags: ToTagSetConvertible, array):
        """
        Return *array* with each array leaf tagged with *tags*, after
        passing them through :func:`_preprocess_array_tags`.
        """
        def _tag(ary):
            return ary.tagged(
                _preprocess_array_tags(tags))

        tagged = self._rec_map_container(_tag, array)
        return tagged

    def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
        """
        Return *array* with *tags* attached to axis *iaxis* of each array
        leaf. The tags are passed on as given.
        """
        def _tag_axis(ary):
            return ary.with_tagged_axis(iaxis, tags)

        axis_tagged = self._rec_map_container(_tag_axis, array)
        return axis_tagged

    # }}}

    # {{{ compilation

    def call_loopy(self, program, **kwargs):
        """
        Call the default entrypoint of the :mod:`loopy` *program* on
        *kwargs* via :func:`pytato.loopy.call_loopy`.

        :arg kwargs: arguments to the kernel. :class:`pytato.Array`
            instances and scalars are passed on as given, while
            ``TaggableCLArray`` instances are thawed first.

        :raises ValueError: if an argument is of any other type.
        """
        import pytato as pt
        from pytato.loopy import call_loopy
        from pytato.scalar_expr import SCALAR_CLASSES

        from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray

        entrypoint = program.default_entrypoint.name

        # {{{ preprocess args

        processed_kwargs = {}

        # arguments are processed in sorted name order
        for kw, arg in sorted(kwargs.items()):
            if isinstance(arg, (pt.Array, *SCALAR_CLASSES)):
                pass
            elif isinstance(arg, TaggableCLArray):
                arg = self.thaw(arg)
            else:
                # NOTE: adjacent literals are joined verbatim, hence the
                # explicit space separating "or" from the next fragment.
                raise ValueError(f"call_loopy argument '{kw}' expected to be an"
                                 " instance of 'pytato.Array', 'Number' or"
                                 f" 'TaggableCLArray', got '{type(arg)}'")

            processed_kwargs[kw] = arg

        # }}}

        return call_loopy(program, processed_kwargs, entrypoint)

    def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
        """
        Return *f* wrapped in a ``LazilyPyOpenCLCompilingFunctionCaller``
        that is tied to this array context.
        """
        from .compile import LazilyPyOpenCLCompilingFunctionCaller as _Caller
        return _Caller(self, f)

    def transform_dag(self, dag: pytato.DictOfNamedArrays
                      ) -> pytato.DictOfNamedArrays:
        """
        Return *dag* after applying
        :func:`pytato.transform.materialize_with_mpms` to it.
        """
        from pytato.transform import materialize_with_mpms
        return materialize_with_mpms(dag)

    def einsum(self, spec, *args, arg_names=None, tagged=()):
        """
        Build a :func:`pytato.einsum` expression for *spec* from *args*.

        :arg args: the operands. ``TaggableCLArray`` instances are thawed,
            other frozen array types are converted (with a deprecation
            warning) and then thawed, and :class:`pytato.Array` instances are
            used as given.
        :arg arg_names: optional names, one per operand. A named operand
            that is not a placeholder and carries no
            :class:`~arraycontext.metadata.NameHint` yet is tagged with one.
        :arg tagged: tags for the result, passed through
            :func:`_preprocess_array_tags`.

        :raises TypeError: if an operand is of an unsupported type.
        """
        import pytato as pt

        from arraycontext.impl.pyopencl.taggable_cl_array import (
            TaggableCLArray,
            to_tagged_cl_array,
        )

        if arg_names is None:
            arg_names = (None,) * len(args)

        def preprocess_arg(name, arg):
            if isinstance(arg, TaggableCLArray):
                ary = self.thaw(arg)
            elif isinstance(arg, self._frozen_array_types):
                from warnings import warn
                warn(f"Invoking {type(self).__name__}.einsum with"
                    f" {type(arg).__name__} will be unsupported in 2023. Use"
                    " `to_tagged_cl_array` to convert instances to TaggableCLArray.",
                    DeprecationWarning, stacklevel=2)
                ary = self.thaw(to_tagged_cl_array(arg))
            elif isinstance(arg, pt.Array):
                ary = arg
            else:
                raise TypeError(
                    f"{type(self).__name__}.einsum invoked with an unsupported "
                    f"array type: got '{type(arg).__name__}', but expected one "
                    f"of {self.array_types}")

            # Tagging Placeholders with naming-related tags is pointless:
            # They already have names. It's also counterproductive, as
            # multiple placeholders with the same name that are not
            # also the same object are not allowed, and this would produce
            # a different Placeholder object of the same name.
            needs_name_hint = (
                name is not None
                and not isinstance(ary, pt.Placeholder)
                and not ary.tags_of_type(NameHint))

            return ary.tagged(NameHint(name)) if needs_name_hint else ary

        operands = [
            preprocess_arg(name, arg)
            for name, arg in zip(arg_names, args, strict=True)
            ]

        result = pt.einsum(spec, *operands)
        return result.tagged(_preprocess_array_tags(tagged))

    def clone(self):
        """
        Return a new context of the same type, constructed from this
        context's queue and allocator. No other constructor argument is
        passed along.
        """
        actx_cls = type(self)
        return actx_cls(self.queue, self.allocator)

    # }}}

# }}}


# {{{ PytatoJAXArrayContext

class PytatoJAXArrayContext(_BasePytatoArrayContext):
    """
    An arraycontext that uses :mod:`pytato` to represent the thawed state of
    the arrays and compiles the expressions using
    :class:`pytato.target.python.JAXPythonTarget`.
    """

    def __init__(self,
            *,
            compile_trace_callback: Callable[[Any, str, Any], None] | None = None,
            ) -> None:
        """
        :arg compile_trace_callback: A function of three arguments
            *(what, stage, ir)*, where *what* identifies the object
            being compiled, *stage* is a string describing the compilation
            pass, and *ir* is an object containing the intermediate
            representation. This interface should be considered
            unstable.
        """
        import jax.numpy as jnp
        import pytato as pt

        super().__init__(compile_trace_callback=compile_trace_callback)

        # array types handled by this context: lazy pytato expressions
        # and JAX arrays
        supported_types = (pt.Array, jnp.ndarray)
        self.array_types = supported_types

    @property
    def _frozen_array_types(self) -> tuple[type, ...]:
        import jax.numpy as jnp
        return (jnp.ndarray, )

    def _rec_map_container(
            self, func: Callable[[Array], Array], array: ArrayOrContainer,
            allowed_types: tuple[type, ...] | None = None, *,
            default_scalar: ScalarLike | None = None,
            strict: bool = False) -> ArrayOrContainer:
        """
        Apply *func* to every leaf of *array* whose type is in
        *allowed_types*.

        :arg allowed_types: leaf types handed to *func*. Defaults to
            ``self.array_types``.
        :arg default_scalar: if not *None*, scalar leaves are replaced by this
            value, converted to the scalar's own dtype; otherwise scalar
            leaves are returned unchanged.
        :arg strict: accepted as part of the signature, but not consulted by
            this implementation.

        :raises TypeError: for a leaf that is neither of an allowed type nor
            a scalar. The message names the calling method by dropping the
            first character of ``func.__name__``.
        """
        accepted = self.array_types if allowed_types is None else allowed_types

        def _apply_to_leaf(leaf):
            if isinstance(leaf, accepted):
                return func(leaf)

            if not np.isscalar(leaf):
                raise TypeError(
                    f"{type(self).__name__}.{func.__name__[1:]} invoked with "
                    f"an unsupported array type: got '{type(leaf).__name__}', "
                    f"but expected one of {accepted}")

            if default_scalar is None:
                return leaf

            return np.array(leaf).dtype.type(default_scalar)

        return rec_map_array_container(_apply_to_leaf, array)

    # {{{ ArrayContext interface

    def from_numpy(self, array):
        """
        Pass the :class:`numpy.ndarray` leaves of *array* through
        :func:`jax.device_put`, wrap each one in a :mod:`pytato` data wrapper
        and attach this array context to the result.
        """
        import jax
        import pytato as pt

        def _from_numpy(ary):
            jax_ary = jax.device_put(ary)
            return pt.make_data_wrapper(jax_ary)

        wrapped = self._rec_map_container(_from_numpy, array, (np.ndarray,))

        return with_array_context(wrapped, actx=self)

    def to_numpy(self, array):
        """
        Freeze *array*, then pass every array leaf through
        :func:`jax.device_get`. The result carries no array context.
        """
        import jax

        def _to_numpy(ary):
            return jax.device_get(ary)

        frozen = self.freeze(array)
        on_host = self._rec_map_container(_to_numpy, frozen)

        return with_array_context(on_host, actx=None)

    def freeze(self, array):
        """
        Evaluate the :mod:`pytato` expressions among the leaves of *array*
        and return a container of JAX arrays with no array context attached.

        Scalars are returned unchanged. JAX arrays and
        :class:`pytato.DataWrapper` leaves only have ``block_until_ready``
        called on them (or on their data). All remaining
        :class:`pytato.Array` leaves are passed through :meth:`transform_dag`
        together and evaluated with a JIT-enabled program generated by
        :func:`pytato.generate_jax`.

        :raises TypeError: if a leaf is of an unsupported type.
        """
        if np.isscalar(array):
            return array

        import jax.numpy as jnp
        import pytato as pt

        from arraycontext.container.traversal import rec_keyed_map_array_container
        from arraycontext.impl.pytato.compile import _ary_container_key_stringifier

        def _leaf_name(key: tuple[Any, ...]) -> str:
            return "_ary" + _ary_container_key_stringifier(key)

        # {{{ collect all leaves by name

        leaves: dict[str, jnp.ndarray | pt.Array] = {}

        def _collect_leaf(key: tuple[Any, ...],
                          ary: jnp.ndarray | pt.Array) -> None:
            leaves[_leaf_name(key)] = ary

        # only used for its side effect of filling *leaves*
        rec_keyed_map_array_container(_collect_leaf, array)

        # }}}

        # {{{ separate leaves needing evaluation from the rest

        frozen_by_name: dict[str, jnp.ndarray] = {}
        lazy_by_name: dict[str, pt.Array] = {}

        for name, leaf in leaves.items():
            if isinstance(leaf, jnp.ndarray):
                frozen_by_name[name] = leaf.block_until_ready()
            elif isinstance(leaf, pt.DataWrapper):
                # trivial freeze.
                frozen_by_name[name] = leaf.data.block_until_ready()
            elif isinstance(leaf, pt.Array):
                lazy_by_name[name] = leaf
            else:
                raise TypeError(
                    f"{type(self).__name__}.freeze invoked with an unsupported "
                    f"array type: got '{type(leaf).__name__}', but expected one "
                    f"of {self.array_types}")

        # }}}

        # If there are no lazy leaves, everything is a JAX array already
        # and no code generation is needed.
        if lazy_by_name:
            transformed_dag = self.transform_dag(
                pt.make_dict_of_named_arrays(lazy_by_name))
            jax_prg = pt.generate_jax(transformed_dag, jit=True)
            evaluated = jax_prg()

            # computed outputs must not collide with the trivially frozen leaves
            assert set(evaluated).isdisjoint(frozen_by_name)

            for name, value in evaluated.items():
                frozen_by_name[name] = value.block_until_ready()

        def _to_frozen(key: tuple[Any, ...], ary) -> jnp.ndarray:
            return frozen_by_name[_leaf_name(key)]

        return with_array_context(
            rec_keyed_map_array_container(_to_frozen, array),
            actx=None)

    def thaw(self, array):
        """
        Wrap every frozen array leaf of *array* in a :mod:`pytato` data
        wrapper and attach this array context to the result.
        """
        import pytato as pt

        def _thaw(ary):
            return pt.make_data_wrapper(ary)

        frozen_types = self._frozen_array_types
        thawed = self._rec_map_container(_thaw, array, frozen_types)

        return with_array_context(thawed, actx=self)

    def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
        """
        Return *f* wrapped in a ``LazilyJAXCompilingFunctionCaller`` that is
        tied to this array context.
        """
        from .compile import LazilyJAXCompilingFunctionCaller as _Caller
        return _Caller(self, f)

    def tag(self, tags: ToTagSetConvertible, array):
        """
        Return *array* with each :mod:`pytato` leaf tagged with *tags*, after
        passing them through :func:`_preprocess_array_tags`. JAX array leaves
        are returned unchanged.
        """
        def _tag(ary):
            import jax.numpy as jnp

            # JAX arrays are handed back as they are
            if not isinstance(ary, jnp.ndarray):
                return ary.tagged(_preprocess_array_tags(tags))

            return ary

        return self._rec_map_container(_tag, array)

    def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
        """
        Return *array* with *tags* attached to axis *iaxis* of each
        :mod:`pytato` leaf. JAX array leaves are returned unchanged.
        """
        def _tag_axis(ary):
            import jax.numpy as jnp

            # JAX arrays are handed back as they are
            if not isinstance(ary, jnp.ndarray):
                return ary.with_tagged_axis(iaxis, tags)

            return ary

        return self._rec_map_container(_tag_axis, array)

    # }}}

    # {{{ compilation

    def call_loopy(self, program, **kwargs):
        """
        Not available for this array context.

        :raises NotImplementedError: always.
        """
        message = (
            "Calling loopy on JAX arrays is not supported. Maybe rewrite"
            " the loopy kernel as numpy-flavored array operations using"
            " ArrayContext.np.")
        raise NotImplementedError(message)

    def einsum(self, spec, *args, arg_names=None, tagged=()):
        """
        Build a :func:`pytato.einsum` expression for *spec* from *args*.

        :arg args: the operands. JAX arrays are thawed, and
            :class:`pytato.Array` instances are used as given.
        :arg arg_names: optional names, one per operand. A named operand
            that is not a placeholder and carries no
            :class:`~arraycontext.metadata.NameHint` yet is tagged with one.
        :arg tagged: tags for the result, passed through
            :func:`_preprocess_array_tags`.

        :raises TypeError: if an operand is of an unsupported type.
        """
        import pytato as pt

        if arg_names is None:
            arg_names = (None,) * len(args)

        def preprocess_arg(name, arg):
            import jax.numpy as jnp

            if isinstance(arg, jnp.ndarray):
                ary = self.thaw(arg)
            elif isinstance(arg, pt.Array):
                ary = arg
            else:
                raise TypeError(
                    f"{type(self).__name__}.einsum invoked with an unsupported "
                    f"array type: got '{type(arg).__name__}', but expected one "
                    f"of {self.array_types}")

            # Tagging Placeholders with naming-related tags is pointless:
            # They already have names. It's also counterproductive, as
            # multiple placeholders with the same name that are not
            # also the same object are not allowed, and this would produce
            # a different Placeholder object of the same name.
            needs_name_hint = (
                name is not None
                and not isinstance(ary, pt.Placeholder)
                and not ary.tags_of_type(NameHint))

            return ary.tagged(NameHint(name)) if needs_name_hint else ary

        operands = [
            preprocess_arg(name, arg)
            for name, arg in zip(arg_names, args, strict=True)
            ]

        result = pt.einsum(spec, *operands)
        return result.tagged(_preprocess_array_tags(tagged))

    def clone(self):
        """
        Return a new context of the same type, constructed without
        arguments.
        """
        actx_cls = type(self)
        return actx_cls()

# }}}

# }}}

# vim: foldmethod=marker