From 1cc7f5eaba8719bcfe582fca6bfc02e0f2c9c8e0 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Fri, 21 Jan 2022 00:27:12 -0600
Subject: [PATCH] Generalize Pytato Array Context to allow multiple targets

Also implements PytatoJAXTarget.

Co-authored-by: Alexandru Fikl <alexfikl@gmail.com>
---
 arraycontext/__init__.py             |   5 +-
 arraycontext/impl/pytato/__init__.py | 262 +++++++++++++++++++-----
 arraycontext/impl/pytato/compile.py  | 290 ++++++++++++++++++++-------
 3 files changed, 441 insertions(+), 116 deletions(-)

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index d9861d6..2dc4abd 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -65,8 +65,7 @@ from .container.traversal import (
 
 from .impl.pyopencl import PyOpenCLArrayContext
 from .impl.pytato import (PytatoPyOpenCLArrayContext,
-                          PytatoJAXArrayContext,
-                          _BasePytatoArrayContext)
+                          PytatoJAXArrayContext)
 from .impl.jax import EagerJAXArrayContext
 
 from .pytest import (
@@ -106,7 +105,7 @@ __all__ = (
         "outer",
 
         "PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext",
-        "PytatoJAXArrayContext", "_BasePytatoArrayContext",
+        "PytatoJAXArrayContext",
         "EagerJAXArrayContext",
 
         "make_loopy_program",
diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index 216bd1e..909e432 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -10,6 +10,7 @@ JIT-compile and execute the array expressions.
 Following :mod:`pytato`-based array context are provided:
 
 .. autoclass:: PytatoPyOpenCLArrayContext
+.. autoclass:: PytatoJAXArrayContext
 
 
 Compiling a python callable
@@ -44,14 +45,85 @@ THE SOFTWARE.
 from arraycontext.context import ArrayContext, _ScalarLike
 from arraycontext.container.traversal import rec_map_array_container
 import numpy as np
-from typing import Any, Callable, Union, TYPE_CHECKING
+from typing import Any, Callable, Union, TYPE_CHECKING, Tuple, Type
 from pytools.tag import ToTagSetConvertible
+import abc
 
 if TYPE_CHECKING:
     import pytato
 
 
-class PytatoPyOpenCLArrayContext(ArrayContext):
+class _BasePytatoArrayContext(ArrayContext, abc.ABC):
+    """
+    An abstract :class:`ArrayContext` that uses :mod:`pytato` data types to
+    represent the arrays. Concrete compilation targets derive from it.
+
+    .. automethod:: __init__
+
+    .. automethod:: transform_dag
+
+    .. automethod:: compile
+    """
+    def __init__(self):
+        super().__init__()
+        self._freeze_prg_cache = {}
+        self._dag_transform_cache = {}
+
+    def _get_fake_numpy_namespace(self):
+        from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace
+        return PytatoFakeNumpyNamespace(self)
+
+    def empty(self, shape, dtype):
+        raise ValueError(f"{type(self).__name__} does not support empty")
+
+    def zeros(self, shape, dtype):
+        import pytato as pt
+        return pt.zeros(shape, dtype)
+
+    def transform_dag(self, dag: "pytato.DictOfNamedArrays"
+                      ) -> "pytato.DictOfNamedArrays":
+        """
+        Returns a transformed version of *dag*. Sub-classes are supposed to
+        override this method to implement context-specific transformations on
+        *dag* (most likely to perform domain-specific optimizations). Every
+        :mod:`pytato` DAG that is compiled to a GPU-kernel is
+        passed through this routine.
+
+        :arg dag: An instance of :class:`pytato.DictOfNamedArrays`
+        :returns: A transformed version of *dag*.
+        """
+        return dag
+
+    def transform_loopy_program(self, t_unit):
+        raise ValueError(f"{type(self).__name__} does not implement "
+                         "transform_loopy_program. Sub-classes are supposed "
+                         "to implement it.")
+
+    @property
+    @abc.abstractmethod
+    def frozen_array_types(self) -> Tuple[Type, ...]:
+        """
+        Returns valid frozen array types for the array context.
+        """
+
+    @abc.abstractmethod
+    def einsum(self, spec, *args, arg_names=None, tagged=()):
+        """Evaluate the einsum *spec* on *args*; implemented by subclasses."""
+
+    @property
+    def permits_inplace_modification(self):
+        return False
+
+    @property
+    def supports_nonscalar_broadcasting(self):
+        return True
+
+    @property
+    def permits_advanced_indexing(self):
+        return True
+
+
+class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
     """
     A :class:`ArrayContext` that uses :mod:`pytato` data types to represent
     the arrays targeting OpenCL for offloading operations.
@@ -79,28 +151,15 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         self.queue = queue
         self.allocator = allocator
         self.array_types = (pt.Array, cla.Array)
-        self._freeze_prg_cache = {}
-        self._dag_transform_cache = {}
 
         # unused, but necessary to keep the context alive
         self.context = self.queue.context
 
-    def _get_fake_numpy_namespace(self):
-        from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace
-        return PytatoFakeNumpyNamespace(self)
-
     # {{{ ArrayContext interface
 
     def clone(self):
         return type(self)(self.queue, self.allocator)
 
-    def empty(self, shape, dtype):
-        raise ValueError("PytatoPyOpenCLArrayContext does not support empty")
-
-    def zeros(self, shape, dtype):
-        import pytato as pt
-        return pt.zeros(shape, dtype)
-
     def from_numpy(self, array: Union[np.ndarray, _ScalarLike]):
         import pytato as pt
         import pyopencl.array as cla
@@ -114,6 +173,11 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         cl_array = self.freeze(array)
         return cl_array.get(queue=self.queue)
 
+    @property
+    def frozen_array_types(self) -> Tuple[Type, ...]:
+        import pyopencl.array as cla
+        return (cla.Array, )
+
     def call_loopy(self, program, **kwargs):
         import pytato as pt
         from pytato.scalar_expr import SCALAR_CLASSES
@@ -167,8 +231,8 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
                                       axes=get_cl_axes_from_pt_axes(array.axes),
                                       tags=array.tags)
         if not isinstance(array, pt.Array):
-            raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with "
-                            f"non-pytato array of type '{type(array)}'")
+            raise TypeError(f"{type(self).__name__}.freeze invoked "
+                            f"with non-pytato array of type '{type(array)}'")
 
         # {{{ early exit for 0-sized arrays
 
@@ -227,7 +291,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         elif isinstance(array, cl_array.Array):
             array = to_tagged_cl_array(array, axes=None, tags=frozenset())
         else:
-            raise TypeError("PytatoPyOpenCLArrayContext.thaw expects "
+            raise TypeError(f"{type(self).__name__}.thaw expects "
                             "'TaggableCLArray' or 'cl.array.Array' got "
                             f"{type(array)}.")
 
@@ -238,30 +302,13 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
     # }}}
 
     def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
-        from arraycontext.impl.pytato.compile import LazilyCompilingFunctionCaller
-        return LazilyCompilingFunctionCaller(self, f)
-
-    def transform_loopy_program(self, t_unit):
-        raise ValueError("PytatoPyOpenCLArrayContext does not implement "
-                         "transform_loopy_program. Sub-classes are supposed "
-                         "to implement it.")
+        from .compile import LazilyPyOpenCLCompilingFunctionCaller
+        return LazilyPyOpenCLCompilingFunctionCaller(self, f)
 
     def transform_dag(self, dag: "pytato.DictOfNamedArrays"
                       ) -> "pytato.DictOfNamedArrays":
-        """
-        Returns a transformed version of *dag*. Sub-classes are supposed to
-        override this method to implement context-specific transformations on
-        *dag* (most likely to perform domain-specific optimizations). Every
-        :mod:`pytato` DAG that is compiled to a :mod:`pyopencl` kernel is
-        passed through this routine.
-
-        :arg dag: An instance of :class:`pytato.DictOfNamedArrays`
-        :returns: A transformed version of *dag*.
-        """
         import pytato as pt
-
         dag = pt.transform.materialize_with_mpms(dag)
-
         return dag
 
     def tag(self, tags: ToTagSetConvertible, array):
@@ -315,14 +362,139 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
             for name, arg in zip(arg_names, args)
             ])
 
-    @property
-    def permits_inplace_modification(self):
-        return False
 
-    @property
-    def supports_nonscalar_broadcasting(self):
-        return True
+class PytatoJAXArrayContext(_BasePytatoArrayContext):
+    """
+    An arraycontext that uses :mod:`pytato` to represent the thawed state of
+    the arrays and compiles the expressions using
+    :class:`pytato.target.python.JAXPythonTarget`.
+    """
+
+    def __init__(self):
+        import pytato as pt
+        from jax.numpy import DeviceArray
+        super().__init__()
+        self.array_types = (pt.Array, DeviceArray)
+
+    def clone(self):
+        return type(self)()
+
+    def from_numpy(self, array: Union[np.ndarray, _ScalarLike]):
+        import jax
+        import pytato as pt
+        return pt.make_data_wrapper(jax.device_put(array))
+
+    def to_numpy(self, array):
+        if np.isscalar(array):
+            return array
+
+        import jax
+        return jax.device_get(self.freeze(array))
 
     @property
-    def permits_advanced_indexing(self):
-        return True
+    def frozen_array_types(self) -> Tuple[Type, ...]:
+        from jax.numpy import DeviceArray
+        return (DeviceArray, )
+
+    def call_loopy(self, program, **kwargs):
+        raise ValueError(f"{type(self).__name__} does not support calling loopy.")
+
+    def freeze(self, array):
+        """Evaluate *array* and return the result as a JAX device array."""
+        import pytato as pt
+        from jax.numpy import DeviceArray
+        if isinstance(array, DeviceArray):
+            return array.block_until_ready()
+        if not isinstance(array, pt.Array):
+            raise TypeError(f"{type(self).__name__}.freeze invoked with "
+                            f"non-pytato array of type '{type(array)}'")
+
+        from arraycontext.impl.pytato.utils import _normalize_pt_expr
+        pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
+                {"_actx_out": array})
+
+        normalized_expr, bound_arguments = _normalize_pt_expr(
+                pt_dict_of_named_arrays)
+        # the jitted program is cached per normalized expression
+        try:
+            pt_prg = self._freeze_prg_cache[normalized_expr]
+        except KeyError:
+            pt_prg = pt.generate_jax(self.transform_dag(normalized_expr),
+                                     jit=True)
+            self._freeze_prg_cache[normalized_expr] = pt_prg
+
+        assert len(pt_prg.bound_arguments) == 0
+        out_dict = pt_prg(**bound_arguments)
+
+        return out_dict["_actx_out"].block_until_ready()
+
+    def thaw(self, array):
+        """Wrap the frozen JAX device array *array* in a pytato data wrapper."""
+        import pytato as pt
+        if not isinstance(array, self.frozen_array_types):
+            raise TypeError(f"{type(self).__name__}.thaw expects jax device "
+                            f"arrays, got {type(array)}")
+
+        return pt.make_data_wrapper(array)
+
+    def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
+        from .compile import LazilyJAXCompilingFunctionCaller
+        return LazilyJAXCompilingFunctionCaller(self, f)
+
+    def tag(self, tags: ToTagSetConvertible, array):
+        import pytato as pt
+        from jax.numpy import DeviceArray
+
+        def _rec_tag(ary):
+            if isinstance(ary, DeviceArray):
+                return ary
+            else:
+                assert isinstance(ary, pt.Array)
+                return ary.tagged(tags)
+
+        return rec_map_array_container(_rec_tag, array)
+
+    def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
+        import pytato as pt
+        from jax.numpy import DeviceArray
+
+        def _rec_tag_axis(ary):
+            if isinstance(ary, DeviceArray):
+                return ary
+            else:
+                assert isinstance(ary, pt.Array)
+                return ary.with_tagged_axis(iaxis, tags)
+
+        return rec_map_array_container(_rec_tag_axis,
+                                       array)
+
+    def einsum(self, spec, *args, arg_names=None, tagged=()):
+        import pytato as pt
+        from jax.numpy import DeviceArray
+        if arg_names is None:
+            arg_names = (None,) * len(args)
+
+        def preprocess_arg(name, arg):
+            if isinstance(arg, DeviceArray):
+                ary = self.thaw(arg)
+            else:
+                assert isinstance(arg, pt.Array)
+                ary = arg
+
+            if name is not None:
+                from pytato.tags import PrefixNamed
+
+                # Tagging Placeholders with naming-related tags is pointless:
+                # They already have names. It's also counterproductive, as
+                # multiple placeholders with the same name that are not
+                # also the same object are not allowed, and this would produce
+                # a different Placeholder object of the same name.
+                if not isinstance(ary, pt.Placeholder):
+                    ary = ary.tagged(PrefixNamed(name))
+
+            return ary
+
+        return pt.einsum(spec, *[
+            preprocess_arg(name, arg)
+            for name, arg in zip(arg_names, args)
+            ])
diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index 8b20a7c..129a2c4 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -1,6 +1,7 @@
 """
-.. currentmodule:: arraycontext.impl.pytato.compile
-.. autoclass:: LazilyCompilingFunctionCaller
+.. autoclass:: BaseLazilyCompilingFunctionCaller
+.. autoclass:: LazilyPyOpenCLCompilingFunctionCaller
+.. autoclass:: LazilyJAXCompilingFunctionCaller
 .. autoclass:: CompiledFunction
 .. autoclass:: FromArrayContextCompile
 """
@@ -30,16 +31,17 @@ THE SOFTWARE.
 
 from arraycontext.container import (ArrayContainer, is_array_container_type,
                                     ArrayT)
-from arraycontext import PytatoPyOpenCLArrayContext
+from arraycontext.impl.pytato import (_BasePytatoArrayContext,
+                                      PytatoJAXArrayContext,
+                                      PytatoPyOpenCLArrayContext)
 from arraycontext.container.traversal import rec_keyed_map_array_container
 
 import abc
 import numpy as np
-from typing import Any, Callable, Tuple, Dict, Mapping, FrozenSet
+from typing import Any, Callable, Tuple, Dict, Mapping, FrozenSet, Type
 from dataclasses import dataclass, field
 from pyrsistent import pmap, PMap
 
-import pyopencl.array as cla
 import pytato as pt
 import itertools
 from pytools.tag import Tag
@@ -65,7 +67,7 @@ class FromArrayContextCompile(Tag):
 
 class AbstractInputDescriptor:
     """
-    Used internally in :class:`LazilyCompilingFunctionCaller` to characterize
+    Used internally in :class:`BaseLazilyCompilingFunctionCaller` to characterize
     an input.
     """
     def __eq__(self, other):
@@ -90,7 +92,7 @@ class LeafArrayDescriptor(AbstractInputDescriptor):
 
 def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
     """
-    Helper for :meth:`LazilyCompilingFunctionCaller.__call__`. Stringifies an
+    Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an
     array-container's component's key. Goals of this routine:
 
     * No two different keys should have the same stringification
@@ -118,7 +120,7 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...],
                                                             AbstractInputDescriptor]\
                                                        ]":
     """
-    Helper for :meth:`LazilyCompilingFunctionCaller.__call__`. Extracts
+    Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Extracts
     mappings from argument id to argument values and from argument id to
     :class:`AbstractInputDescriptor`. See
     :attr:`CompiledFunction.input_id_to_name_in_program` for argument-id's
@@ -190,9 +192,9 @@ def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext):
 
 def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx):
     """
-    Helper for :class:`LazilyCompilingFunctionCaller.__call__`. Returns the
+    Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. Returns the
     placeholder version of an argument to
-    :attr:`LazilyCompilingFunctionCaller.f`.
+    :attr:`BaseLazilyCompilingFunctionCaller.f`.
     """
     if np.isscalar(arg):
         name = arg_id_to_name[(kw,)]
@@ -221,12 +223,10 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx):
 
 
 @dataclass
-class LazilyCompilingFunctionCaller:
+class BaseLazilyCompilingFunctionCaller:
     """
-    Records a side-effect-free callable
-    :attr:`LazilyCompilingFunctionCaller.f` that can be specialized for the
-    input types with which :meth:`LazilyCompilingFunctionCaller.__call__` is
-    invoked.
+    Records a side-effect-free callable :attr:`f` that can be specialized for
+    the input types with which :meth:`__call__` is invoked.
 
     .. attribute:: f
 
@@ -235,48 +235,26 @@ class LazilyCompilingFunctionCaller:
     .. automethod:: __call__
     """
 
-    actx: PytatoPyOpenCLArrayContext
+    actx: _BasePytatoArrayContext
     f: Callable[..., Any]
     program_cache: Dict["PMap[Tuple[Any, ...], AbstractInputDescriptor]",
                         "CompiledFunction"] = field(default_factory=lambda: {})
 
-    def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays):
-        from pytato.target.loopy import BoundPyOpenCLProgram
-
-        import loopy as lp
-
-        with ProcessLogger(logger, "transform_dag"):
-            pt_dict_of_named_arrays = self.actx.transform_dag(dict_of_named_arrays)
-
-        name_in_program_to_tags = {
-            name: out.tags
-            for name, out in pt_dict_of_named_arrays._data.items()}
-        name_in_program_to_axes = {
-            name: out.axes
-            for name, out in pt_dict_of_named_arrays._data.items()}
-
-        with ProcessLogger(logger, "generate_loopy"):
-            pytato_program = pt.generate_loopy(pt_dict_of_named_arrays,
-                                               options=lp.Options(
-                                                   return_dict=True,
-                                                   no_numpy=True),
-                                               cl_device=self.actx.queue.device)
-            assert isinstance(pytato_program, BoundPyOpenCLProgram)
+    # {{{ abstract interface
 
-        with ProcessLogger(logger, "transform_loopy_program"):
+    def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays):
+        raise NotImplementedError
 
-            pytato_program = (pytato_program
-                              .with_transformed_program(
-                                  lambda x: x.with_kernel(
-                                      x.default_entrypoint
-                                      .tagged(FromArrayContextCompile()))))
+    @property
+    def compiled_function_returning_array_container_class(
+            self) -> Type["CompiledFunction"]:
+        raise NotImplementedError
 
-            pytato_program = (pytato_program
-                              .with_transformed_program(self
-                                                        .actx
-                                                        .transform_loopy_program))
+    @property
+    def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]:
+        raise NotImplementedError
 
-        return pytato_program, name_in_program_to_tags, name_in_program_to_axes
+    # }}}
 
     def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays,
             input_id_to_name_in_program, output_id_to_name_in_program,
@@ -286,8 +264,8 @@ class LazilyCompilingFunctionCaller:
             dict_of_named_arrays = pt.make_dict_of_named_arrays(
                 {output_id: ary_or_dict_of_named_arrays})
             pytato_program, name_in_program_to_tags, name_in_program_to_axes = (
-                self._dag_to_transformed_loopy_prg(dict_of_named_arrays))
-            return CompiledFunctionReturningArray(
+                self._dag_to_transformed_pytato_prg(dict_of_named_arrays))
+            return self.compiled_function_returning_array_class(
                 self.actx, pytato_program,
                 input_id_to_name_in_program=input_id_to_name_in_program,
                 output_tags=name_in_program_to_tags[output_id],
@@ -295,8 +273,8 @@ class LazilyCompilingFunctionCaller:
                 output_name=output_id)
         elif isinstance(ary_or_dict_of_named_arrays, pt.DictOfNamedArrays):
             pytato_program, name_in_program_to_tags, name_in_program_to_axes = (
-                self._dag_to_transformed_loopy_prg(ary_or_dict_of_named_arrays))
-            return CompiledFunctionReturningArrayContainer(
+                self._dag_to_transformed_pytato_prg(ary_or_dict_of_named_arrays))
+            return self.compiled_function_returning_array_container_class(
                     self.actx, pytato_program,
                     input_id_to_name_in_program=input_id_to_name_in_program,
                     output_id_to_name_in_program=output_id_to_name_in_program,
@@ -308,12 +286,12 @@ class LazilyCompilingFunctionCaller:
 
     def __call__(self, *args: Any, **kwargs: Any) -> Any:
         """
-        Returns the result of :attr:`~LazilyCompilingFunctionCaller.f`'s
+        Returns the result of :attr:`~BaseLazilyCompilingFunctionCaller.f`'s
         function application on *args*.
 
-        Before applying :attr:`~LazilyCompilingFunctionCaller.f`, it is compiled
+        Before applying :attr:`~BaseLazilyCompilingFunctionCaller.f`, it is compiled
         to a :mod:`pytato` DAG that would apply
-        :attr:`~LazilyCompilingFunctionCaller.f` with *args* in a lazy-sense.
+        :attr:`~BaseLazilyCompilingFunctionCaller.f` with *args* in a lazy-sense.
         The intermediary pytato DAG for *args* is memoized in *self*.
         """
         arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(
@@ -370,23 +348,127 @@ class LazilyCompilingFunctionCaller:
         return compiled_func(arg_id_to_arg)
 
 
-def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
-    from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
+class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller):
+    @property
+    def compiled_function_returning_array_container_class(
+            self) -> Type["CompiledFunction"]:
+        return CompiledPyOpenCLFunctionReturningArrayContainer
 
+    @property
+    def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]:
+        return CompiledPyOpenCLFunctionReturningArray
+
+    def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays):
+        from pytato.target.loopy import BoundPyOpenCLProgram
+
+        import loopy as lp
+
+        with ProcessLogger(logger, "transform_dag"):
+            pt_dict_of_named_arrays = self.actx.transform_dag(dict_of_named_arrays)
+
+        name_in_program_to_tags = {
+            name: out.tags
+            for name, out in pt_dict_of_named_arrays._data.items()}
+        name_in_program_to_axes = {
+            name: out.axes
+            for name, out in pt_dict_of_named_arrays._data.items()}
+
+        with ProcessLogger(logger, "generate_loopy"):
+            pytato_program = pt.generate_loopy(pt_dict_of_named_arrays,
+                                               options=lp.Options(
+                                                   return_dict=True,
+                                                   no_numpy=True),
+                                               # pylint: disable=no-member
+                                               cl_device=self.actx.queue.device)
+            assert isinstance(pytato_program, BoundPyOpenCLProgram)
+
+        with ProcessLogger(logger, "transform_loopy_program"):
+
+            pytato_program = (pytato_program
+                              .with_transformed_program(
+                                  lambda x: x.with_kernel(
+                                      x.default_entrypoint
+                                      .tagged(FromArrayContextCompile()))))
+
+            pytato_program = (pytato_program
+                              .with_transformed_program(self
+                                                        .actx
+                                                        .transform_loopy_program))
+
+        return pytato_program, name_in_program_to_tags, name_in_program_to_axes
+
+
+# {{{ preserve back compat
+
+class LazilyCompilingFunctionCaller(LazilyPyOpenCLCompilingFunctionCaller):
+    def __new__(cls, *args, **kwargs):
+        from warnings import warn
+        warn("LazilyCompilingFunctionCaller has been renamed to"
+             " LazilyPyOpenCLCompilingFunctionCaller. This will be"
+             " an error in 2023.", DeprecationWarning, stacklevel=2)
+        return super(LazilyCompilingFunctionCaller, cls).__new__(cls)
+
+    def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays):
+        from warnings import warn
+        warn("_dag_to_transformed_loopy_prg has been renamed to"
+             " _dag_to_transformed_pytato_prg. This will be"
+             " an error in 2023.", DeprecationWarning, stacklevel=2)
+        return super()._dag_to_transformed_pytato_prg(dict_of_named_arrays)
+
+# }}}
+
+
+class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller):
+    @property
+    def compiled_function_returning_array_container_class(
+            self) -> Type["CompiledFunction"]:
+        return CompiledJAXFunctionReturningArrayContainer
+
+    @property
+    def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]:
+        return CompiledJAXFunctionReturningArray
+
+    def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays):
+        """Transform the DAG, then lower it with ``pt.generate_jax``."""
+        with ProcessLogger(logger, "transform_dag"):
+            pt_dict_of_named_arrays = self.actx.transform_dag(dict_of_named_arrays)
+
+        name_in_program_to_tags = {
+            name: out.tags
+            for name, out in pt_dict_of_named_arrays._data.items()}
+        name_in_program_to_axes = {
+            name: out.axes
+            for name, out in pt_dict_of_named_arrays._data.items()}
+
+        with ProcessLogger(logger, "generate_jax"):
+            pytato_program = pt.generate_jax(pt_dict_of_named_arrays, jit=True)
+
+        return pytato_program, name_in_program_to_tags, name_in_program_to_axes
+
+
+def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
     input_kwargs_for_loopy = {}
 
     for arg_id, arg in arg_id_to_arg.items():
         if np.isscalar(arg):
-            arg = cla.to_device(actx.queue, np.array(arg))
+            if isinstance(actx, PytatoPyOpenCLArrayContext):
+                import pyopencl.array as cla
+                arg = cla.to_device(actx.queue, np.array(arg))
+            elif isinstance(actx, PytatoJAXArrayContext):
+                import jax
+                arg = jax.device_put(arg)
+            else:
+                raise NotImplementedError(type(actx))
+
         elif isinstance(arg, pt.array.DataWrapper):
-            # got a Datwwrapper => simply gets its data
+            # got a DataWrapper => simply get its data
             arg = arg.data
-        elif isinstance(arg, TaggableCLArray):
+        elif isinstance(arg, actx.frozen_array_types):
             # got a frozen array  => do nothing
             pass
         elif isinstance(arg, pt.Array):
             # got an array expression => evaluate it
-            arg = actx.freeze(arg).with_queue(actx.queue)
+            arg = actx.freeze(arg)
         else:
             raise NotImplementedError(type(arg))
 
@@ -395,10 +477,19 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
     return input_kwargs_for_loopy
 
 
+def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
+    """Deprecated alias of :func:`_args_to_device_buffers`."""
+    from warnings import warn
+    warn("_args_to_cl_buffers has been renamed to _args_to_device_buffers."
+         " This will be an error in 2023.", DeprecationWarning, stacklevel=2)
+    return _args_to_device_buffers(actx, input_id_to_name_in_program,
+                                   arg_id_to_arg)
+
+
 class CompiledFunction(abc.ABC):
     """
     A callable which captures the :class:`pytato.target.BoundProgram`  resulting
-    from calling :attr:`~LazilyCompilingFunctionCaller.f` with a given set of
+    from calling :attr:`~BaseLazilyCompilingFunctionCaller.f` with a given set of
     input types, and generating :mod:`loopy` IR from it.
 
     .. attribute:: pytato_program
@@ -407,7 +498,7 @@ class CompiledFunction(abc.ABC):
 
         A mapping from input id to the placeholder name in
         :attr:`CompiledFunction.pytato_program`. Input id is represented as the
-        position of :attr:`~LazilyCompilingFunctionCaller.f`'s argument augmented
+        position of :attr:`~BaseLazilyCompilingFunctionCaller.f`'s argument augmented
         with the leaf array's key if the argument is an array container.
 
 
@@ -425,7 +516,7 @@ class CompiledFunction(abc.ABC):
 
 
 @dataclass(frozen=True)
-class CompiledFunctionReturningArrayContainer(CompiledFunction):
+class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction):
     """
     .. attribute:: output_id_to_name_in_program
 
@@ -452,7 +543,7 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction):
         from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
         from .utils import get_cl_axes_from_pt_axes
 
-        input_kwargs_for_loopy = _args_to_cl_buffers(
+        input_kwargs_for_loopy = _args_to_device_buffers(
                 self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
 
         evt, out_dict = self.pytato_program(queue=self.actx.queue,
@@ -477,7 +568,7 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction):
 
 
 @dataclass(frozen=True)
-class CompiledFunctionReturningArray(CompiledFunction):
+class CompiledPyOpenCLFunctionReturningArray(CompiledFunction):
     """
     .. attribute:: output_name_in_program
 
@@ -494,7 +585,7 @@ class CompiledFunctionReturningArray(CompiledFunction):
         from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
         from .utils import get_cl_axes_from_pt_axes
 
-        input_kwargs_for_loopy = _args_to_cl_buffers(
+        input_kwargs_for_loopy = _args_to_device_buffers(
                 self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
 
         evt, out_dict = self.pytato_program(queue=self.actx.queue,
@@ -510,3 +601,66 @@ class CompiledFunctionReturningArray(CompiledFunction):
                                                  axes=get_cl_axes_from_pt_axes(
                                                      self.output_axes),
                                                  tags=self.output_tags))
+
+
+@dataclass(frozen=True)
+class CompiledJAXFunctionReturningArrayContainer(CompiledFunction):
+    """
+    .. attribute:: output_id_to_name_in_program
+
+        A mapping from output id to the name of
+        :class:`pytato.array.NamedArray` in
+        :attr:`CompiledFunction.pytato_program`. Output id is represented by
+        the key of a leaf array in the array container
+        :attr:`CompiledFunction.output_template`.
+
+    .. attribute:: output_template
+
+       An instance of :class:`arraycontext.ArrayContainer` that is the return
+       type of the callable.
+    """
+    actx: PytatoJAXArrayContext
+    pytato_program: pt.target.BoundProgram
+    input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
+    output_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
+    name_in_program_to_tags: Mapping[str, FrozenSet[Tag]]
+    name_in_program_to_axes: Mapping[str, Tuple[pt.Axis, ...]]
+    output_template: ArrayContainer
+
+    def __call__(self, arg_id_to_arg) -> ArrayContainer:
+        input_kwargs_for_loopy = _args_to_device_buffers(
+                self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
+
+        out_dict = self.pytato_program(**input_kwargs_for_loopy)
+
+        def to_output_template(keys, _):
+            return self.actx.thaw(
+                out_dict[self.output_id_to_name_in_program[keys]]
+                .block_until_ready()
+            )
+
+        return rec_keyed_map_array_container(to_output_template,
+                                             self.output_template)
+
+
+@dataclass(frozen=True)
+class CompiledJAXFunctionReturningArray(CompiledFunction):
+    """
+    .. attribute:: output_name
+
+        Name of the output array in the program.
+    """
+    actx: PytatoJAXArrayContext
+    pytato_program: pt.target.BoundProgram
+    input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
+    output_tags: FrozenSet[Tag]
+    output_axes: Tuple[pt.Axis, ...]
+    output_name: str
+
+    def __call__(self, arg_id_to_arg) -> ArrayContainer:
+        input_kwargs_for_loopy = _args_to_device_buffers(
+                self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
+        # The JAX program returns just the output dict -- there is no event.
+        out_dict = self.pytato_program(**input_kwargs_for_loopy)
+
+        return self.actx.thaw(out_dict[self.output_name])
-- 
GitLab