From 70f251f20af68f3828e9e6953a3a596434fecff6 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Sun, 26 Jun 2022 09:03:21 +0300 Subject: [PATCH] pytato: add more array container support --- arraycontext/impl/pytato/__init__.py | 479 ++++++++++++++++----------- arraycontext/impl/pytato/compile.py | 2 +- 2 files changed, 295 insertions(+), 186 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index ed3bff8..8d7e042 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -42,17 +42,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import abc import sys -from arraycontext.context import ArrayContext, ScalarLike -from arraycontext.container.traversal import (rec_map_array_container, - with_array_context) -from arraycontext.metadata import NameHint +from typing import (Any, Callable, Union, Tuple, Type, FrozenSet, Dict, Optional, + TYPE_CHECKING) import numpy as np -from typing import (Any, Callable, Union, TYPE_CHECKING, Tuple, Type, FrozenSet, - Dict, Optional) from pytools.tag import ToTagSetConvertible, normalize_tags, Tag -import abc + +from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike +from arraycontext.container.traversal import (rec_map_array_container, + with_array_context) +from arraycontext.metadata import NameHint if TYPE_CHECKING: import pytato @@ -105,9 +106,11 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC): .. automethod:: compile """ - def __init__(self, - *, compile_trace_callback: Optional[Callable[[Any, str, Any], None]] - = None) -> None: + + def __init__( + self, *, + compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None + ) -> None: """ :arg compile_trace_callback: A function of three arguments *(what, stage, ir)*, where *what* identifies the object @@ -116,9 +119,10 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC): representation. This interface should be considered unstable. """ + super().__init__() + import pytato as pt import loopy as lp - super().__init__() self._freeze_prg_cache: Dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {} self._dag_transform_cache: Dict[ pt.DictOfNamedArrays, @@ -136,13 +140,30 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC): from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace return PytatoFakeNumpyNamespace(self) + @abc.abstractproperty + def _frozen_array_types(self) -> Tuple[Type, ...]: + """ + Returns valid frozen array types for the array context. + """ + + # {{{ ArrayContext interface + def empty(self, shape, dtype): - raise ValueError(f"{type(self).__name__} does not support empty") + raise NotImplementedError( + f"{type(self).__name__}.empty is not supported") def zeros(self, shape, dtype): import pytato as pt return pt.zeros(shape, dtype) + def empty_like(self, ary): + raise NotImplementedError( + f"{type(self).__name__}.empty_like is not supported") + + # }}} + + # {{{ compilation + def transform_dag(self, dag: "pytato.DictOfNamedArrays" ) -> "pytato.DictOfNamedArrays": """ @@ -158,21 +179,18 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC): return dag def transform_loopy_program(self, t_unit): - raise ValueError(f"{type(self)} does not implement " - "transform_loopy_program. Sub-classes are supposed " - "to implement it.") - - @abc.abstractproperty - def frozen_array_types(self) -> Tuple[Type, ...]: - """ - Returns valid frozen array types for the array context. - """ - pass + raise ValueError( + f"{type(self).__name__} does not implement transform_loopy_program. " + "Sub-classes are supposed to implement it.") @abc.abstractmethod def einsum(self, spec, *args, arg_names=None, tagged=()): pass + # }}} + + # {{{ properties + @property def permits_inplace_modification(self): return False @@ -185,6 +203,8 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC): def permits_advanced_indexing(self): return True + # }}} + # }}} @@ -210,10 +230,10 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): .. automethod:: compile """ - def __init__(self, queue: "cl.CommandQueue", allocator=None, - *, - compile_trace_callback: Optional[Callable[[Any, str, Any], None]] - = None) -> None: + def __init__( + self, queue: "cl.CommandQueue", allocator=None, *, + compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None + ) -> None: """ :arg compile_trace_callback: A function of three arguments *(what, stage, ir)*, where *what* identifies the object @@ -232,63 +252,82 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): # unused, but necessary to keep the context alive self.context = self.queue.context - # {{{ ArrayContext interface - - def clone(self): - return type(self)(self.queue, self.allocator) + @property + def _frozen_array_types(self) -> Tuple[Type, ...]: + import pyopencl.array as cla + return (cla.Array,) - def from_numpy(self, array: Union[np.ndarray, ScalarLike]): + def _rec_map_container( + self, func: Callable[[Array], Array], array: ArrayOrContainer, + allowed_types: Optional[Tuple[type, ...]] = None, *, + default_scalar: Optional[ScalarLike] = None, + strict: bool = False) -> ArrayOrContainer: import pytato as pt - import pyopencl.array as cla - cl_array = cla.to_device(self.queue, array) - return pt.make_data_wrapper(cl_array) + import arraycontext.impl.pyopencl.taggable_cl_array as tga - def to_numpy(self, array): - if np.isscalar(array): - return array + if allowed_types is None: + allowed_types = (pt.Array, tga.TaggableCLArray) - cl_array = self.freeze(array) - return cl_array.get(queue=self.queue) + def _wrapper(ary): + if isinstance(ary, allowed_types): + return func(ary) + elif not strict and isinstance(ary, self._frozen_array_types): + from warnings import warn + warn(f"Invoking {type(self).__name__}.{func.__name__[1:]} with" + f" {type(ary).__name__} will be unsupported in 2023. Use" + " 'to_tagged_cl_array' to convert instances to" + " TaggableCLArray.", DeprecationWarning, stacklevel=2) + + return func(tga.to_tagged_cl_array(ary)) + elif np.isscalar(ary): + if default_scalar is None: + return ary + else: + return np.array(ary).dtype.type(default_scalar) + else: + raise TypeError( + f"{type(self).__name__}.{func.__name__[1:]} invoked with " + f"an unsupported array type: got '{type(ary).__name__}', " + f"but expected one of {allowed_types}") - @property - def frozen_array_types(self) -> Tuple[Type, ...]: - import pyopencl.array as cla - return (cla.Array, ) + return rec_map_array_container(_wrapper, array) - def call_loopy(self, program, **kwargs): - import pytato as pt - from pytato.scalar_expr import SCALAR_CLASSES - from pytato.loopy import call_loopy - from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray + # {{{ ArrayContext interface - entrypoint = program.default_entrypoint.name + def zeros_like(self, ary): + def _zeros_like(array): + return self.zeros(array.shape, array.dtype) - # {{{ preprocess args + return self._rec_map_container(_zeros_like, ary, default_scalar=0) - processed_kwargs = {} + def from_numpy(self, array): + import pytato as pt + import arraycontext.impl.pyopencl.taggable_cl_array as tga - for kw, arg in sorted(kwargs.items()): - if isinstance(arg, (pt.Array,) + SCALAR_CLASSES): - pass - elif isinstance(arg, TaggableCLArray): - arg = self.thaw(arg) - else: - raise ValueError(f"call_loopy argument '{kw}' expected to be an" - " instance of 'pytato.Array', 'Number' or" - f"'TaggableCLArray', got '{type(arg)}'") + def _from_numpy(ary): + return pt.make_data_wrapper( + tga.to_device(self.queue, ary, allocator=self.allocator) + ) - processed_kwargs[kw] = arg + return with_array_context( + self._rec_map_container(_from_numpy, array, (np.ndarray,), strict=True), + actx=self) - # }}} + def to_numpy(self, array): + def _to_numpy(ary): + return ary.get(queue=self.queue) - return call_loopy(program, processed_kwargs, entrypoint) + return with_array_context( + self._rec_map_container(_to_numpy, self.freeze(array)), + actx=None) def freeze(self, array): + if np.isscalar(array): + return array + import pytato as pt import pyopencl.array as cla - import loopy as lp - from arraycontext.context import ArrayT from arraycontext.container.traversal import rec_keyed_map_array_container from arraycontext.impl.pytato.utils import (_normalize_pt_expr, get_cl_axes_from_pt_axes) @@ -296,16 +335,15 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): TaggableCLArray) from arraycontext.impl.pytato.compile import _ary_container_key_stringifier - array_as_dict: Dict[str, Union[cla.Array, TaggableCLArray, - pt.Array]] = {} + array_as_dict: Dict[str, Union[cla.Array, TaggableCLArray, pt.Array]] = {} key_to_frozen_subary: Dict[str, TaggableCLArray] = {} key_to_pt_arrays: Dict[str, pt.Array] = {} - def _record_leaf_ary_in_dict(key: Tuple[Any, ...], - ary: ArrayT): + def _record_leaf_ary_in_dict( + key: Tuple[Any, ...], + ary: Union[cla.Array, TaggableCLArray, pt.Array]) -> None: key_str = "_ary" + _ary_container_key_stringifier(key) array_as_dict[key_str] = ary - return ary rec_keyed_map_array_container(_record_leaf_ary_in_dict, array) @@ -314,37 +352,37 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): for key, subary in array_as_dict.items(): if isinstance(subary, TaggableCLArray): key_to_frozen_subary[key] = subary.with_queue(None) - elif isinstance(subary, cla.Array): + elif isinstance(subary, self._frozen_array_types): from warnings import warn - warn("Freezing pyopencl.array.Array will be deprecated in 2023." - " Use `to_tagged_cl_array` to convert the array to" - " TaggableCLArray", DeprecationWarning, stacklevel=2) - key_to_frozen_subary[key] = to_tagged_cl_array( - subary.with_queue(None), - axes=None, - tags=frozenset()) + warn(f"Invoking {type(self).__name__}.freeze with" + f" {type(subary).__name__} will be unsupported in 2023. Use" + " `to_tagged_cl_array` to convert instances to TaggableCLArray.", + DeprecationWarning, stacklevel=2) + + key_to_frozen_subary[key] = ( + to_tagged_cl_array(subary.with_queue(None))) elif isinstance(subary, pt.DataWrapper): # trivial freeze. key_to_frozen_subary[key] = to_tagged_cl_array( subary.data, axes=get_cl_axes_from_pt_axes(subary.axes), tags=subary.tags) - else: - if not isinstance(subary, pt.Array): - raise TypeError(f"{type(self).__name__}.freeze invoked " - f"with non-pytato array of type '{type(array)}'") - + elif isinstance(subary, pt.Array): # Don't be tempted to take shortcuts here, e.g. for empty # arrays, as this will inhibit metadata propagation that # may happen in transform_dag below. See # https://github.com/inducer/arraycontext/pull/167#issuecomment-1151877480 key_to_pt_arrays[key] = subary + else: + raise TypeError( + f"{type(self).__name__}.freeze invoked with an unsupported " + f"array type: got '{type(subary).__name__}', but expected one " + f"of {self.array_types}") # }}} pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( - key_to_pt_arrays) - + key_to_pt_arrays) normalized_expr, bound_arguments = _normalize_pt_expr( pt_dict_of_named_arrays) @@ -352,8 +390,8 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): pt_prg = self._freeze_prg_cache[normalized_expr] except KeyError: try: - transformed_dag, function_name = \ - self._dag_transform_cache[normalized_expr] + transformed_dag, function_name = ( + self._dag_transform_cache[normalized_expr]) except KeyError: transformed_dag = self.transform_dag(normalized_expr) @@ -373,16 +411,16 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): self._dag_transform_cache[normalized_expr] = ( transformed_dag, function_name) + from arraycontext.loopy import _DEFAULT_LOOPY_OPTIONS pt_prg = pt.generate_loopy(transformed_dag, - options=lp.Options(return_dict=True, - no_numpy=True), + options=_DEFAULT_LOOPY_OPTIONS, cl_device=self.queue.device, function_name=function_name) pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program) self._freeze_prg_cache[normalized_expr] = pt_prg else: - transformed_dag, function_name = \ - self._dag_transform_cache[normalized_expr] + transformed_dag, function_name = ( + self._dag_transform_cache[normalized_expr]) assert len(pt_prg.bound_arguments) == 0 evt, out_dict = pt_prg(self.queue, **bound_arguments) @@ -391,47 +429,79 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): key_to_frozen_subary = { **key_to_frozen_subary, - **{k: to_tagged_cl_array(v.with_queue(None), - get_cl_axes_from_pt_axes(transformed_dag[k] - .expr - .axes), - transformed_dag[k].expr.tags) + **{k: to_tagged_cl_array( + v.with_queue(None), + axes=get_cl_axes_from_pt_axes(transformed_dag[k].expr.axes), + tags=transformed_dag[k].expr.tags) for k, v in out_dict.items()} } - def _to_frozen(key: Tuple[Any, ...], ary: ArrayT): + def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray: key_str = "_ary" + _ary_container_key_stringifier(key) return key_to_frozen_subary[key_str] - return with_array_context(rec_keyed_map_array_container(_to_frozen, - array), - actx=None) + return with_array_context( + rec_keyed_map_array_container(_to_frozen, array), + actx=None) def thaw(self, array): import pytato as pt from .utils import get_pt_axes_from_cl_axes - from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray, - to_tagged_cl_array) - import pyopencl.array as cl_array + import arraycontext.impl.pyopencl.taggable_cl_array as tga - def _rec_thaw(ary): - if isinstance(ary, TaggableCLArray): - pass - elif isinstance(ary, cl_array.Array): - ary = to_tagged_cl_array(ary, axes=None, tags=frozenset()) - else: - raise TypeError(f"{type(self).__name__}.thaw expects " - "'TaggableCLArray' or 'cl.array.Array' got " - f"{type(ary)}.") + def _thaw(ary): return pt.make_data_wrapper(ary.with_queue(self.queue), axes=get_pt_axes_from_cl_axes(ary.axes), tags=ary.tags) - return with_array_context(rec_map_array_container(_rec_thaw, array), - actx=self) + return with_array_context( + self._rec_map_container(_thaw, array, (tga.TaggableCLArray,)), + actx=self) + + def tag(self, tags: ToTagSetConvertible, array): + def _tag(ary): + return ary.tagged(_preprocess_array_tags(tags)) + + return self._rec_map_container(_tag, array) + + def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): + def _tag_axis(ary): + return ary.with_tagged_axis(iaxis, tags) + + return self._rec_map_container(_tag_axis, array) # }}} + # {{{ compilation + + def call_loopy(self, program, **kwargs): + import pytato as pt + from pytato.scalar_expr import SCALAR_CLASSES + from pytato.loopy import call_loopy + from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray + + entrypoint = program.default_entrypoint.name + + # {{{ preprocess args + + processed_kwargs = {} + + for kw, arg in sorted(kwargs.items()): + if isinstance(arg, (pt.Array,) + SCALAR_CLASSES): + pass + elif isinstance(arg, TaggableCLArray): + arg = self.thaw(arg) + else: + raise ValueError(f"call_loopy argument '{kw}' expected to be an" + " instance of 'pytato.Array', 'Number' or" + f"'TaggableCLArray', got '{type(arg)}'") + + processed_kwargs[kw] = arg + + # }}} + + return call_loopy(program, processed_kwargs, entrypoint) + def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: from .compile import LazilyPyOpenCLCompilingFunctionCaller return LazilyPyOpenCLCompilingFunctionCaller(self, f) @@ -442,39 +512,30 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): dag = pt.transform.materialize_with_mpms(dag) return dag - def tag(self, tags: ToTagSetConvertible, array): - return rec_map_array_container( - lambda x: x.tagged(_preprocess_array_tags(tags)), - array) - - def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): - return rec_map_array_container( - lambda x: x.with_tagged_axis(iaxis, tags), - array) - def einsum(self, spec, *args, arg_names=None, tagged=()): - import pyopencl.array as cla import pytato as pt - from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray, - to_tagged_cl_array) + import arraycontext.impl.pyopencl.taggable_cl_array as tga + if arg_names is None: arg_names = (None,) * len(args) def preprocess_arg(name, arg): - if isinstance(arg, TaggableCLArray): + if isinstance(arg, tga.TaggableCLArray): ary = self.thaw(arg) - elif isinstance(arg, cla.Array): + elif isinstance(arg, self._frozen_array_types): from warnings import warn - warn("Passing pyopencl.array.Array to einsum will be " - "deprecated in 2023." - " Use `to_tagged_cl_array` to convert the array to" - " TaggableCLArray.", DeprecationWarning, stacklevel=2) - ary = self.thaw(to_tagged_cl_array(arg, - axes=None, - tags=frozenset())) - else: - assert isinstance(arg, pt.Array) + warn(f"Invoking {type(self).__name__}.einsum with" + f" {type(arg).__name__} will be unsupported in 2023. Use" + " `to_tagged_cl_array` to convert instances to TaggableCLArray.", + DeprecationWarning, stacklevel=2) + ary = self.thaw(tga.to_tagged_cl_array(arg)) + elif isinstance(arg, pt.Array): ary = arg + else: + raise TypeError( + f"{type(self).__name__}.einsum invoked with an unsupported " + f"array type: got '{type(arg).__name__}', but expected one " + f"of {self.array_types}") if name is not None: # Tagging Placeholders with naming-related tags is pointless: @@ -493,6 +554,11 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): for name, arg in zip(arg_names, args) ]).tagged(_preprocess_array_tags(tagged)) + def clone(self): + return type(self)(self.queue, self.allocator) + + # }}} + # }}} @@ -521,34 +587,71 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): super().__init__(compile_trace_callback=compile_trace_callback) self.array_types = (pt.Array, DeviceArray) - def clone(self): - return type(self)() + @property + def _frozen_array_types(self) -> Tuple[Type, ...]: + from jax.numpy import DeviceArray + return (DeviceArray, ) + + def _rec_map_container( + self, func: Callable[[Array], Array], array: ArrayOrContainer, + allowed_types: Optional[Tuple[type, ...]] = None, *, + default_scalar: Optional[ScalarLike] = None, + strict: bool = False) -> ArrayOrContainer: + if allowed_types is None: + allowed_types = self.array_types + + def _wrapper(ary): + if isinstance(ary, allowed_types): + return func(ary) + elif np.isscalar(ary): + if default_scalar is None: + return ary + else: + return np.array(ary).dtype.type(default_scalar) + else: + raise TypeError( + f"{type(self).__name__}.{func.__name__[1:]} invoked with " + f"an unsupported array type: got '{type(ary).__name__}', " + f"but expected one of {allowed_types}") + + return rec_map_array_container(_wrapper, array) + + # {{{ ArrayContext interface + + def zeros_like(self, ary): + def _zeros_like(array): + return self.zeros(array.shape, array.dtype) - def from_numpy(self, array: Union[np.ndarray, ScalarLike]): + return self._rec_map_container(_zeros_like, ary, default_scalar=0) + + def from_numpy(self, array): import jax import pytato as pt - return pt.make_data_wrapper(jax.device_put(array)) - def to_numpy(self, array): - if np.isscalar(array): - return array + def _from_numpy(ary): + return pt.make_data_wrapper(jax.device_put(ary)) + + return with_array_context( + self._rec_map_container(_from_numpy, array, (np.ndarray,)), + actx=self) + def to_numpy(self, array): import jax - return jax.device_get(self.freeze(array)) - @property - def frozen_array_types(self) -> Tuple[Type, ...]: - from jax.numpy import DeviceArray - return (DeviceArray, ) + def _to_numpy(ary): + return jax.device_get(ary) - def call_loopy(self, program, **kwargs): - raise ValueError(f"{type(self)} does not support calling loopy.") + return with_array_context( + self._rec_map_container(_to_numpy, self.freeze(array)), + actx=None) def freeze(self, array): + if np.isscalar(array): + return array + import pytato as pt from jax.numpy import DeviceArray - from arraycontext.context import ArrayT from arraycontext.container.traversal import rec_keyed_map_array_container from arraycontext.impl.pytato.compile import _ary_container_key_stringifier @@ -557,10 +660,9 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): key_to_pt_arrays: Dict[str, pt.Array] = {} def _record_leaf_ary_in_dict(key: Tuple[Any, ...], - ary: Union[DeviceArray, pt.Array]): + ary: Union[DeviceArray, pt.Array]) -> None: key_str = "_ary" + _ary_container_key_stringifier(key) array_as_dict[key_str] = ary - return ary rec_keyed_map_array_container(_record_leaf_ary_in_dict, array) @@ -572,12 +674,13 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): elif isinstance(subary, pt.DataWrapper): # trivial freeze. key_to_frozen_subary[key] = subary.data.block_until_ready() - else: - if not isinstance(subary, pt.Array): - raise TypeError(f"{type(self).__name__}.freeze invoked " - f"with non-pytato array of type '{type(array)}'") - + elif isinstance(subary, pt.Array): key_to_pt_arrays[key] = subary + else: + raise TypeError( + f"{type(self).__name__}.freeze invoked with an unsupported " + f"array type: got '{type(subary).__name__}', but expected one " + f"of {self.array_types}") # }}} @@ -593,59 +696,59 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): for k, v in out_dict.items()} } - def _to_frozen(key: Tuple[Any, ...], ary: ArrayT): + def _to_frozen(key: Tuple[Any, ...], ary) -> DeviceArray: key_str = "_ary" + _ary_container_key_stringifier(key) return key_to_frozen_subary[key_str] - return with_array_context(rec_keyed_map_array_container(_to_frozen, - array), - actx=None) + return with_array_context( + rec_keyed_map_array_container(_to_frozen, array), + actx=None) def thaw(self, array): import pytato as pt - from jax.numpy import DeviceArray - def _rec_thaw(ary): - if isinstance(ary, DeviceArray): - pass - else: - raise TypeError(f"{type(self).__name__}.thaw expects " - f"'jax.DeviceArray' got {type(ary)}.") + def _thaw(ary): return pt.make_data_wrapper(ary) - return with_array_context(rec_map_array_container(_rec_thaw, array), - actx=self) + return with_array_context( + self._rec_map_container(_thaw, array, self._frozen_array_types), + actx=self) def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: from .compile import LazilyJAXCompilingFunctionCaller return LazilyJAXCompilingFunctionCaller(self, f) def tag(self, tags: ToTagSetConvertible, array): - import pytato as pt from jax.numpy import DeviceArray - def _rec_tag(ary): + def _tag(ary): if isinstance(ary, DeviceArray): return ary else: - assert isinstance(ary, pt.Array) return ary.tagged(_preprocess_array_tags(tags)) - return rec_map_array_container(_rec_tag, array) + return self._rec_map_container(_tag, array) def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): - import pytato as pt from jax.numpy import DeviceArray - def _rec_tag_axis(ary): + def _tag_axis(ary): if isinstance(ary, DeviceArray): return ary else: - assert isinstance(ary, pt.Array) return ary.with_tagged_axis(iaxis, tags) - return rec_map_array_container(_rec_tag_axis, - array) + return self._rec_map_container(_tag_axis, array) + + # }}} + + # {{{ compilation + + def call_loopy(self, program, **kwargs): + raise NotImplementedError( + "Calling loopy on JAX arrays is not supported. Maybe rewrite" + " the loopy kernel as numpy-flavored array operations using" + " ArrayContext.np.") def einsum(self, spec, *args, arg_names=None, tagged=()): import pytato as pt @@ -656,9 +759,13 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): def preprocess_arg(name, arg): if isinstance(arg, DeviceArray): ary = self.thaw(arg) - else: - assert isinstance(arg, pt.Array) + elif isinstance(arg, pt.Array): ary = arg + else: + raise TypeError( + f"{type(self).__name__}.einsum invoked with an unsupported " + f"array type: got '{type(arg).__name__}', but expected one " + f"of {self.array_types}") if name is not None: # Tagging Placeholders with naming-related tags is pointless: @@ -677,7 +784,9 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): for name, arg in zip(arg_names, args) ]).tagged(_preprocess_array_tags(tagged)) -# }}} + def clone(self): + return type(self)() +# }}} # vim: foldmethod=marker diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 9e92adf..53170ca 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -523,7 +523,7 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): elif isinstance(arg, pt.array.DataWrapper): # got a Datawrapper => simply gets its data arg = arg.data - elif isinstance(arg, actx.frozen_array_types): + elif isinstance(arg, actx._frozen_array_types): # got a frozen array => do nothing pass elif isinstance(arg, pt.Array): -- GitLab