From 1cc7f5eaba8719bcfe582fca6bfc02e0f2c9c8e0 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Fri, 21 Jan 2022 00:27:12 -0600 Subject: [PATCH] Generalize Pytato Array Context to allow multiple targets Also implements PytatoJAXTarget. Co-authored-by: Alexandru Fikl <alexfikl@gmail.com> --- arraycontext/__init__.py | 5 +- arraycontext/impl/pytato/__init__.py | 262 +++++++++++++++++++----- arraycontext/impl/pytato/compile.py | 290 ++++++++++++++++++++------- 3 files changed, 441 insertions(+), 116 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index d9861d6..2dc4abd 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -65,8 +65,7 @@ from .container.traversal import ( from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import (PytatoPyOpenCLArrayContext, - PytatoJAXArrayContext, - _BasePytatoArrayContext) + PytatoJAXArrayContext) from .impl.jax import EagerJAXArrayContext from .pytest import ( @@ -106,7 +105,7 @@ __all__ = ( "outer", "PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext", - "PytatoJAXArrayContext", "_BasePytatoArrayContext", + "PytatoJAXArrayContext", "EagerJAXArrayContext", "make_loopy_program", diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 216bd1e..909e432 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -10,6 +10,7 @@ JIT-compile and execute the array expressions. Following :mod:`pytato`-based array context are provided: .. autoclass:: PytatoPyOpenCLArrayContext +.. autoclass:: PytatoJAXArrayContext Compiling a python callable @@ -44,14 +45,85 @@ THE SOFTWARE. from arraycontext.context import ArrayContext, _ScalarLike from arraycontext.container.traversal import rec_map_array_container import numpy as np -from typing import Any, Callable, Union, TYPE_CHECKING +from typing import Any, Callable, Union, TYPE_CHECKING, Tuple, Type from pytools.tag import ToTagSetConvertible +import abc if TYPE_CHECKING: import pytato -class PytatoPyOpenCLArrayContext(ArrayContext): +class _BasePytatoArrayContext(ArrayContext, abc.ABC): + """ + An abstract :class:`ArrayContext` that uses :mod:`pytato` data types to + represent. + + .. automethod:: __init__ + + .. automethod:: transform_dag + + .. automethod:: compile + """ + def __init__(self): + super().__init__() + self._freeze_prg_cache = {} + self._dag_transform_cache = {} + + def _get_fake_numpy_namespace(self): + from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace + return PytatoFakeNumpyNamespace(self) + + def empty(self, shape, dtype): + raise ValueError(f"{type(self).__name__} does not support empty") + + def zeros(self, shape, dtype): + import pytato as pt + return pt.zeros(shape, dtype) + + def transform_dag(self, dag: "pytato.DictOfNamedArrays" + ) -> "pytato.DictOfNamedArrays": + """ + Returns a transformed version of *dag*. Sub-classes are supposed to + override this method to implement context-specific transformations on + *dag* (most likely to perform domain-specific optimizations). Every + :mod:`pytato` DAG that is compiled to a GPU-kernel is + passed through this routine. + + :arg dag: An instance of :class:`pytato.DictOfNamedArrays` + :returns: A transformed version of *dag*. + """ + return dag + + def transform_loopy_program(self, t_unit): + raise ValueError(f"{type(self)} does not implement " + "transform_loopy_program. Sub-classes are supposed " + "to implement it.") + + @abc.abstractproperty + def frozen_array_types(self) -> Tuple[Type, ...]: + """ + Returns valid frozen array types for the array context. + """ + pass + + @abc.abstractmethod + def einsum(self, spec, *args, arg_names=None, tagged=()): + pass + + @property + def permits_inplace_modification(self): + return False + + @property + def supports_nonscalar_broadcasting(self): + return True + + @property + def permits_advanced_indexing(self): + return True + + +class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): """ A :class:`ArrayContext` that uses :mod:`pytato` data types to represent the arrays targeting OpenCL for offloading operations. @@ -79,28 +151,15 @@ class PytatoPyOpenCLArrayContext(ArrayContext): self.queue = queue self.allocator = allocator self.array_types = (pt.Array, cla.Array) - self._freeze_prg_cache = {} - self._dag_transform_cache = {} # unused, but necessary to keep the context alive self.context = self.queue.context - def _get_fake_numpy_namespace(self): - from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace - return PytatoFakeNumpyNamespace(self) - # {{{ ArrayContext interface def clone(self): return type(self)(self.queue, self.allocator) - def empty(self, shape, dtype): - raise ValueError("PytatoPyOpenCLArrayContext does not support empty") - - def zeros(self, shape, dtype): - import pytato as pt - return pt.zeros(shape, dtype) - def from_numpy(self, array: Union[np.ndarray, _ScalarLike]): import pytato as pt import pyopencl.array as cla @@ -114,6 +173,11 @@ class PytatoPyOpenCLArrayContext(ArrayContext): cl_array = self.freeze(array) return cl_array.get(queue=self.queue) + @property + def frozen_array_types(self) -> Tuple[Type, ...]: + import pyopencl.array as cla + return (cla.Array, ) + def call_loopy(self, program, **kwargs): import pytato as pt from pytato.scalar_expr import SCALAR_CLASSES @@ -167,8 +231,8 @@ class PytatoPyOpenCLArrayContext(ArrayContext): axes=get_cl_axes_from_pt_axes(array.axes), tags=array.tags) if not isinstance(array, pt.Array): - raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with " - f"non-pytato array of type '{type(array)}'") + raise TypeError(f"{type(self).__name__}.freeze invoked " + f"with non-pytato array of type '{type(array)}'") # {{{ early exit for 0-sized arrays @@ -227,7 +291,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): elif isinstance(array, cl_array.Array): array = to_tagged_cl_array(array, axes=None, tags=frozenset()) else: - raise TypeError("PytatoPyOpenCLArrayContext.thaw expects " + raise TypeError(f"{type(self).__name__}.thaw expects " "'TaggableCLArray' or 'cl.array.Array' got " f"{type(array)}.") @@ -238,30 +302,13 @@ class PytatoPyOpenCLArrayContext(ArrayContext): # }}} def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: - from arraycontext.impl.pytato.compile import LazilyCompilingFunctionCaller - return LazilyCompilingFunctionCaller(self, f) - - def transform_loopy_program(self, t_unit): - raise ValueError("PytatoPyOpenCLArrayContext does not implement " - "transform_loopy_program. Sub-classes are supposed " - "to implement it.") + from .compile import LazilyPyOpenCLCompilingFunctionCaller + return LazilyPyOpenCLCompilingFunctionCaller(self, f) def transform_dag(self, dag: "pytato.DictOfNamedArrays" ) -> "pytato.DictOfNamedArrays": - """ - Returns a transformed version of *dag*. Sub-classes are supposed to - override this method to implement context-specific transformations on - *dag* (most likely to perform domain-specific optimizations). Every - :mod:`pytato` DAG that is compiled to a :mod:`pyopencl` kernel is - passed through this routine. - - :arg dag: An instance of :class:`pytato.DictOfNamedArrays` - :returns: A transformed version of *dag*. - """ import pytato as pt - dag = pt.transform.materialize_with_mpms(dag) - return dag def tag(self, tags: ToTagSetConvertible, array): @@ -315,14 +362,139 @@ class PytatoPyOpenCLArrayContext(ArrayContext): for name, arg in zip(arg_names, args) ]) - @property - def permits_inplace_modification(self): - return False - @property - def supports_nonscalar_broadcasting(self): - return True +class PytatoJAXArrayContext(_BasePytatoArrayContext): + """ + An arraycontext that uses :mod:`pytato` to represent the thawed state of + the arrays and compiles the expressions using + :class:`pytato.target.python.JAXPythonTarget`. + """ + + def __init__(self): + import pytato as pt + from jax.numpy import DeviceArray + super().__init__() + self.array_types = (pt.Array, DeviceArray) + + def clone(self): + return type(self)() + + def from_numpy(self, array: Union[np.ndarray, _ScalarLike]): + import jax + import pytato as pt + return pt.make_data_wrapper(jax.device_put(array)) + + def to_numpy(self, array): + if np.isscalar(array): + return array + + import jax + return jax.device_get(self.freeze(array)) @property - def permits_advanced_indexing(self): - return True + def frozen_array_types(self) -> Tuple[Type, ...]: + from jax.numpy import DeviceArray + return (DeviceArray, ) + + def call_loopy(self, program, **kwargs): + raise ValueError(f"{type(self)} does not support calling loopy.") + + def freeze(self, array): + import pytato as pt + from jax.numpy import DeviceArray + + if isinstance(array, DeviceArray): + return array.block_until_ready() + if not isinstance(array, pt.Array): + raise TypeError(f"{type(self)}.freeze invoked with " + f"non-pytato array of type '{type(array)}'") + + from arraycontext.impl.pytato.utils import _normalize_pt_expr + pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( + {"_actx_out": array}) + + normalized_expr, bound_arguments = _normalize_pt_expr( + pt_dict_of_named_arrays) + + try: + pt_prg = self._freeze_prg_cache[normalized_expr] + except KeyError: + pt_prg = pt.generate_jax(self.transform_dag(normalized_expr), + jit=True) + self._freeze_prg_cache[normalized_expr] = pt_prg + + assert len(pt_prg.bound_arguments) == 0 + out_dict = pt_prg(**bound_arguments) + + return out_dict["_actx_out"].block_until_ready() + + def thaw(self, array): + import pytato as pt + + if not isinstance(array, self.frozen_array_types): + raise TypeError(f"{type(self)}.thaw expects jax device arrays, got " + f"{type(array)}") + + return pt.make_data_wrapper(array) + + def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + from .compile import LazilyJAXCompilingFunctionCaller + return LazilyJAXCompilingFunctionCaller(self, f) + + def tag(self, tags: ToTagSetConvertible, array): + import pytato as pt + from jax.numpy import DeviceArray + + def _rec_tag(ary): + if isinstance(ary, DeviceArray): + return ary + else: + assert isinstance(ary, pt.Array) + return ary.tagged(tags) + + return rec_map_array_container(_rec_tag, array) + + def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): + import pytato as pt + from jax.numpy import DeviceArray + + def _rec_tag_axis(ary): + if isinstance(ary, DeviceArray): + return ary + else: + assert isinstance(ary, pt.Array) + return ary.with_tagged_axis(iaxis, tags) + + return rec_map_array_container(_rec_tag_axis, + array) + + def einsum(self, spec, *args, arg_names=None, tagged=()): + import pytato as pt + from jax.numpy import DeviceArray + if arg_names is None: + arg_names = (None,) * len(args) + + def preprocess_arg(name, arg): + if isinstance(arg, DeviceArray): + ary = self.thaw(arg) + else: + assert isinstance(arg, pt.Array) + ary = arg + + if name is not None: + from pytato.tags import PrefixNamed + + # Tagging Placeholders with naming-related tags is pointless: + # They already have names. It's also counterproductive, as + # multiple placeholders with the same name that are not + # also the same object are not allowed, and this would produce + # a different Placeholder object of the same name. + if not isinstance(ary, pt.Placeholder): + ary = ary.tagged(PrefixNamed(name)) + + return ary + + return pt.einsum(spec, *[ + preprocess_arg(name, arg) + for name, arg in zip(arg_names, args) + ]) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 8b20a7c..129a2c4 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -1,6 +1,7 @@ """ -.. currentmodule:: arraycontext.impl.pytato.compile -.. autoclass:: LazilyCompilingFunctionCaller +.. autoclass:: BaseLazilyCompilingFunctionCaller +.. autoclass:: LazilyPyOpenCLCompilingFunctionCaller +.. autoclass:: LazilyJAXCompilingFunctionCaller .. autoclass:: CompiledFunction .. autoclass:: FromArrayContextCompile """ @@ -30,16 +31,17 @@ THE SOFTWARE. from arraycontext.container import (ArrayContainer, is_array_container_type, ArrayT) -from arraycontext import PytatoPyOpenCLArrayContext +from arraycontext.impl.pytato import (_BasePytatoArrayContext, + PytatoJAXArrayContext, + PytatoPyOpenCLArrayContext) from arraycontext.container.traversal import rec_keyed_map_array_container import abc import numpy as np -from typing import Any, Callable, Tuple, Dict, Mapping, FrozenSet +from typing import Any, Callable, Tuple, Dict, Mapping, FrozenSet, Type from dataclasses import dataclass, field from pyrsistent import pmap, PMap -import pyopencl.array as cla import pytato as pt import itertools from pytools.tag import Tag @@ -65,7 +67,7 @@ class FromArrayContextCompile(Tag): class AbstractInputDescriptor: """ - Used internally in :class:`LazilyCompilingFunctionCaller` to characterize + Used internally in :class:`BaseLazilyCompilingFunctionCaller` to characterize an input. """ def __eq__(self, other): @@ -90,7 +92,7 @@ class LeafArrayDescriptor(AbstractInputDescriptor): def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: """ - Helper for :meth:`LazilyCompilingFunctionCaller.__call__`. Stringifies an + Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an array-container's component's key. Goals of this routine: * No two different keys should have the same stringification @@ -118,7 +120,7 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], AbstractInputDescriptor]\ ]": """ - Helper for :meth:`LazilyCompilingFunctionCaller.__call__`. Extracts + Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Extracts mappings from argument id to argument values and from argument id to :class:`AbstractInputDescriptor`. See :attr:`CompiledFunction.input_id_to_name_in_program` for argument-id's @@ -190,9 +192,9 @@ def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext): def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): """ - Helper for :class:`LazilyCompilingFunctionCaller.__call__`. Returns the + Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. Returns the placeholder version of an argument to - :attr:`LazilyCompilingFunctionCaller.f`. + :attr:`BaseLazilyCompilingFunctionCaller.f`. """ if np.isscalar(arg): name = arg_id_to_name[(kw,)] @@ -221,12 +223,10 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): @dataclass -class LazilyCompilingFunctionCaller: +class BaseLazilyCompilingFunctionCaller: """ - Records a side-effect-free callable - :attr:`LazilyCompilingFunctionCaller.f` that can be specialized for the - input types with which :meth:`LazilyCompilingFunctionCaller.__call__` is - invoked. + Records a side-effect-free callable :attr:`f` that can be specialized for + the input types with which :meth:`__call__` is invoked. .. attribute:: f @@ -235,48 +235,26 @@ class LazilyCompilingFunctionCaller: .. automethod:: __call__ """ - actx: PytatoPyOpenCLArrayContext + actx: _BasePytatoArrayContext f: Callable[..., Any] program_cache: Dict["PMap[Tuple[Any, ...], AbstractInputDescriptor]", "CompiledFunction"] = field(default_factory=lambda: {}) - def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays): - from pytato.target.loopy import BoundPyOpenCLProgram - - import loopy as lp - - with ProcessLogger(logger, "transform_dag"): - pt_dict_of_named_arrays = self.actx.transform_dag(dict_of_named_arrays) - - name_in_program_to_tags = { - name: out.tags - for name, out in pt_dict_of_named_arrays._data.items()} - name_in_program_to_axes = { - name: out.axes - for name, out in pt_dict_of_named_arrays._data.items()} - - with ProcessLogger(logger, "generate_loopy"): - pytato_program = pt.generate_loopy(pt_dict_of_named_arrays, - options=lp.Options( - return_dict=True, - no_numpy=True), - cl_device=self.actx.queue.device) - assert isinstance(pytato_program, BoundPyOpenCLProgram) + # {{{ abstract interface - with ProcessLogger(logger, "transform_loopy_program"): + def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays): + raise NotImplementedError - pytato_program = (pytato_program - .with_transformed_program( - lambda x: x.with_kernel( - x.default_entrypoint - .tagged(FromArrayContextCompile())))) + @property + def compiled_function_returning_array_container_class( + self) -> Type["CompiledFunction"]: + raise NotImplementedError - pytato_program = (pytato_program - .with_transformed_program(self - .actx - .transform_loopy_program)) + @property + def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]: + raise NotImplementedError - return pytato_program, name_in_program_to_tags, name_in_program_to_axes + # }}} def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays, input_id_to_name_in_program, output_id_to_name_in_program, @@ -286,8 +264,8 @@ class LazilyCompilingFunctionCaller: dict_of_named_arrays = pt.make_dict_of_named_arrays( {output_id: ary_or_dict_of_named_arrays}) pytato_program, name_in_program_to_tags, name_in_program_to_axes = ( - self._dag_to_transformed_loopy_prg(dict_of_named_arrays)) - return CompiledFunctionReturningArray( + self._dag_to_transformed_pytato_prg(dict_of_named_arrays)) + return self.compiled_function_returning_array_class( self.actx, pytato_program, input_id_to_name_in_program=input_id_to_name_in_program, output_tags=name_in_program_to_tags[output_id], @@ -295,8 +273,8 @@ class LazilyCompilingFunctionCaller: output_name=output_id) elif isinstance(ary_or_dict_of_named_arrays, pt.DictOfNamedArrays): pytato_program, name_in_program_to_tags, name_in_program_to_axes = ( - self._dag_to_transformed_loopy_prg(ary_or_dict_of_named_arrays)) - return CompiledFunctionReturningArrayContainer( + self._dag_to_transformed_pytato_prg(ary_or_dict_of_named_arrays)) + return self.compiled_function_returning_array_container_class( self.actx, pytato_program, input_id_to_name_in_program=input_id_to_name_in_program, output_id_to_name_in_program=output_id_to_name_in_program, @@ -308,12 +286,12 @@ class LazilyCompilingFunctionCaller: def __call__(self, *args: Any, **kwargs: Any) -> Any: """ - Returns the result of :attr:`~LazilyCompilingFunctionCaller.f`'s + Returns the result of :attr:`~BaseLazilyCompilingFunctionCaller.f`'s function application on *args*. - Before applying :attr:`~LazilyCompilingFunctionCaller.f`, it is compiled + Before applying :attr:`~BaseLazilyCompilingFunctionCaller.f`, it is compiled to a :mod:`pytato` DAG that would apply - :attr:`~LazilyCompilingFunctionCaller.f` with *args* in a lazy-sense. + :attr:`~BaseLazilyCompilingFunctionCaller.f` with *args* in a lazy-sense. The intermediary pytato DAG for *args* is memoized in *self*. """ arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( @@ -370,23 +348,127 @@ class LazilyCompilingFunctionCaller: return compiled_func(arg_id_to_arg) -def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): - from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray +class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): + @property + def compiled_function_returning_array_container_class( + self) -> Type["CompiledFunction"]: + return CompiledPyOpenCLFunctionReturningArrayContainer + @property + def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]: + return CompiledPyOpenCLFunctionReturningArray + + def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays): + from pytato.target.loopy import BoundPyOpenCLProgram + + import loopy as lp + + with ProcessLogger(logger, "transform_dag"): + pt_dict_of_named_arrays = self.actx.transform_dag(dict_of_named_arrays) + + name_in_program_to_tags = { + name: out.tags + for name, out in pt_dict_of_named_arrays._data.items()} + name_in_program_to_axes = { + name: out.axes + for name, out in pt_dict_of_named_arrays._data.items()} + + with ProcessLogger(logger, "generate_loopy"): + pytato_program = pt.generate_loopy(pt_dict_of_named_arrays, + options=lp.Options( + return_dict=True, + no_numpy=True), + # pylint: disable=no-member + cl_device=self.actx.queue.device) + assert isinstance(pytato_program, BoundPyOpenCLProgram) + + with ProcessLogger(logger, "transform_loopy_program"): + + pytato_program = (pytato_program + .with_transformed_program( + lambda x: x.with_kernel( + x.default_entrypoint + .tagged(FromArrayContextCompile())))) + + pytato_program = (pytato_program + .with_transformed_program(self + .actx + .transform_loopy_program)) + + return pytato_program, name_in_program_to_tags, name_in_program_to_axes + + +# {{{ preserve back compat + +class LazilyCompilingFunctionCaller(LazilyPyOpenCLCompilingFunctionCaller): + def __new__(cls, *args, **kwargs): + from warnings import warn + warn("LazilyCompilingFunctionCaller has been renamed to" + " LazilyPyOpenCLCompilingFunctionCaller. This will be" + " an error in 2023.", DeprecationWarning, stacklevel=2) + return super(LazilyCompilingFunctionCaller, cls).__new__(cls) + + def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays): + from warnings import warn + warn("_dag_to_transformed_loopy_prg has been renamed to" + " _dag_to_transformed_pytato_prg. This will be" + " an error in 2023.", DeprecationWarning, stacklevel=2) + return super()._dag_to_transformed_pytato_prg(dict_of_named_arrays) + +# }}} + + +class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): + @property + def compiled_function_returning_array_container_class( + self) -> Type["CompiledFunction"]: + return CompiledJAXFunctionReturningArrayContainer + + @property + def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]: + return CompiledJAXFunctionReturningArray + + def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays): + + with ProcessLogger(logger, "transform_dag"): + pt_dict_of_named_arrays = self.actx.transform_dag(dict_of_named_arrays) + + name_in_program_to_tags = { + name: out.tags + for name, out in pt_dict_of_named_arrays._data.items()} + name_in_program_to_axes = { + name: out.axes + for name, out in pt_dict_of_named_arrays._data.items()} + + with ProcessLogger(logger, "generate_jax"): + pytato_program = pt.generate_jax(pt_dict_of_named_arrays, jit=True) + + return pytato_program, name_in_program_to_tags, name_in_program_to_axes + + +def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): input_kwargs_for_loopy = {} for arg_id, arg in arg_id_to_arg.items(): if np.isscalar(arg): - arg = cla.to_device(actx.queue, np.array(arg)) + if isinstance(actx, PytatoPyOpenCLArrayContext): + import pyopencl.array as cla + arg = cla.to_device(actx.queue, np.array(arg)) + elif isinstance(actx, PytatoJAXArrayContext): + import jax + arg = jax.device_put(arg) + else: + raise NotImplementedError(type(actx)) + elif isinstance(arg, pt.array.DataWrapper): - # got a Datwwrapper => simply gets its data + # got a Datawrapper => simply gets its data arg = arg.data - elif isinstance(arg, TaggableCLArray): + elif isinstance(arg, actx.frozen_array_types): # got a frozen array => do nothing pass elif isinstance(arg, pt.Array): # got an array expression => evaluate it - arg = actx.freeze(arg).with_queue(actx.queue) + arg = actx.freeze(arg) else: raise NotImplementedError(type(arg)) @@ -395,10 +477,19 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): return input_kwargs_for_loopy +def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): + from warnings import warn + warn("_args_to_cl_buffer has been renamed to" + " _args_to_device_buffers. This will be" + " an error in 2023.", DeprecationWarning, stacklevel=2) + return _args_to_device_buffers(actx, input_id_to_name_in_program, + arg_id_to_arg) + + class CompiledFunction(abc.ABC): """ A callable which captures the :class:`pytato.target.BoundProgram` resulting - from calling :attr:`~LazilyCompilingFunctionCaller.f` with a given set of + from calling :attr:`~BaseLazilyCompilingFunctionCaller.f` with a given set of input types, and generating :mod:`loopy` IR from it. .. attribute:: pytato_program @@ -407,7 +498,7 @@ class CompiledFunction(abc.ABC): A mapping from input id to the placeholder name in :attr:`CompiledFunction.pytato_program`. Input id is represented as the - position of :attr:`~LazilyCompilingFunctionCaller.f`'s argument augmented + position of :attr:`~BaseLazilyCompilingFunctionCaller.f`'s argument augmented with the leaf array's key if the argument is an array container. @@ -425,7 +516,7 @@ class CompiledFunction(abc.ABC): @dataclass(frozen=True) -class CompiledFunctionReturningArrayContainer(CompiledFunction): +class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction): """ .. attribute:: output_id_to_name_in_program @@ -452,7 +543,7 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction): from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array from .utils import get_cl_axes_from_pt_axes - input_kwargs_for_loopy = _args_to_cl_buffers( + input_kwargs_for_loopy = _args_to_device_buffers( self.actx, self.input_id_to_name_in_program, arg_id_to_arg) evt, out_dict = self.pytato_program(queue=self.actx.queue, @@ -477,7 +568,7 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction): @dataclass(frozen=True) -class CompiledFunctionReturningArray(CompiledFunction): +class CompiledPyOpenCLFunctionReturningArray(CompiledFunction): """ .. attribute:: output_name_in_program @@ -494,7 +585,7 @@ class CompiledFunctionReturningArray(CompiledFunction): from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array from .utils import get_cl_axes_from_pt_axes - input_kwargs_for_loopy = _args_to_cl_buffers( + input_kwargs_for_loopy = _args_to_device_buffers( self.actx, self.input_id_to_name_in_program, arg_id_to_arg) evt, out_dict = self.pytato_program(queue=self.actx.queue, @@ -510,3 +601,66 @@ class CompiledFunctionReturningArray(CompiledFunction): axes=get_cl_axes_from_pt_axes( self.output_axes), tags=self.output_tags)) + + +@dataclass(frozen=True) +class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): + """ + .. attribute:: output_id_to_name_in_program + + A mapping from output id to the name of + :class:`pytato.array.NamedArray` in + :attr:`CompiledFunction.pytato_program`. Output id is represented by + the key of a leaf array in the array container + :attr:`CompiledFunction.output_template`. + + .. attribute:: output_template + + An instance of :class:`arraycontext.ArrayContainer` that is the return + type of the callable. + """ + actx: PytatoJAXArrayContext + pytato_program: pt.target.BoundProgram + input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] + output_id_to_name_in_program: Mapping[Tuple[Any, ...], str] + name_in_program_to_tags: Mapping[str, FrozenSet[Tag]] + name_in_program_to_axes: Mapping[str, Tuple[pt.Axis, ...]] + output_template: ArrayContainer + + def __call__(self, arg_id_to_arg) -> ArrayContainer: + input_kwargs_for_loopy = _args_to_device_buffers( + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) + + out_dict = self.pytato_program(**input_kwargs_for_loopy) + + def to_output_template(keys, _): + return self.actx.thaw( + out_dict[self.output_id_to_name_in_program[keys]] + .block_until_ready() + ) + + return rec_keyed_map_array_container(to_output_template, + self.output_template) + + +@dataclass(frozen=True) +class CompiledJAXFunctionReturningArray(CompiledFunction): + """ + .. attribute:: output_name_in_program + + Name of the output array in the program. + """ + actx: PytatoJAXArrayContext + pytato_program: pt.target.BoundProgram + input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] + output_tags: FrozenSet[Tag] + output_axes: Tuple[pt.Axis, ...] + output_name: str + + def __call__(self, arg_id_to_arg) -> ArrayContainer: + input_kwargs_for_loopy = _args_to_device_buffers( + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) + + evt, out_dict = self.pytato_program(**input_kwargs_for_loopy) + + return self.actx.thaw(out_dict[self.output_name]) -- GitLab