diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 2dc4abde4901818a8f86ba9034f87efc565c0210..3665cf6d6f42a03b04c25463c8333f6e1aa664b6 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -61,7 +61,7 @@ from .container.traversal import ( thaw, freeze, flatten, unflatten, flat_size_and_dtype, from_numpy, to_numpy, - outer) + outer, with_array_context) from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import (PytatoPyOpenCLArrayContext, @@ -101,7 +101,7 @@ __all__ = ( "rec_map_reduce_array_container", "rec_multimap_reduce_array_container", "thaw", "freeze", "flatten", "unflatten", "flat_size_and_dtype", - "from_numpy", "to_numpy", + "from_numpy", "to_numpy", "with_array_context", "outer", "PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext", diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index cc3bbd50384b0a06ec1cd5add1e73b822aa22cc5..5866ab85331a7a6c32e96dc5d578cba777635e1a 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -65,13 +65,15 @@ THE SOFTWARE. from typing import Any, Callable, Iterable, List, Optional, Union, Tuple from functools import update_wrapper, partial, singledispatch +from warnings import warn import numpy as np from arraycontext.context import ArrayContext, Array, _ScalarLike from arraycontext.container import ( ArrayT, ContainerT, ArrayOrContainerT, NotAnArrayContainerError, - serialize_container, deserialize_container) + serialize_container, deserialize_container, + get_container_context_recursively_opt) # {{{ array container traversal helpers @@ -519,7 +521,6 @@ def rec_multimap_reduce_array_container( # {{{ freeze/thaw -@singledispatch def freeze( ary: ArrayOrContainerT, actx: Optional[ArrayContext] = None) -> ArrayOrContainerT: @@ -533,23 +534,33 @@ def freeze( See :meth:`ArrayContext.thaw`. """ - try: - iterable = serialize_container(ary) - except NotAnArrayContainerError: - if actx is None: - raise TypeError( - f"cannot freeze arrays of type {type(ary).__name__} " - "when actx is not supplied. Try calling actx.freeze " - "directly or supplying an array context") - else: - return actx.freeze(ary) + + if actx is None: + warn("Calling freeze(ary) without specifying actx is deprecated, explicitly" + " call actx.freeze(ary) instead. This will stop working in 2023.", + DeprecationWarning, stacklevel=2) + + actx = get_container_context_recursively_opt(ary) else: - return deserialize_container(ary, [ - (key, freeze(subary, actx=actx)) for key, subary in iterable - ]) + warn("Calling freeze(ary, actx) is deprecated, call actx.freeze(ary)" + " instead. This will stop working in 2023.", + DeprecationWarning, stacklevel=2) + + if __debug__: + rec_actx = get_container_context_recursively_opt(ary) + if (rec_actx is not None) and (rec_actx is not actx): + raise ValueError("Supplied array context does not agree with" + " the one obtained by traversing 'ary'.") + + if actx is None: + raise TypeError( + f"cannot freeze arrays of type {type(ary).__name__} " + "when actx is not supplied. Try calling actx.freeze " + "directly or supplying an array context") + + return actx.freeze(ary) -@singledispatch def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: r"""Thaws recursively by going through all components of the :class:`ArrayContainer` *ary*. @@ -570,14 +581,41 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: in :mod:`meshmode`. This was necessary because :func:`~functools.singledispatch` only dispatches on the first argument. """ + warn("Calling thaw(ary, actx) is deprecated, call actx.thaw(ary) instead." + " This will stop working in 2023.", + DeprecationWarning, stacklevel=2) + + if __debug__: + rec_actx = get_container_context_recursively_opt(ary) + if rec_actx is not None: + raise ValueError("cannot thaw a container that already has an array" + " context.") + + return actx.thaw(ary) + +# }}} + + +# {{{ with_array_context + +@singledispatch +def with_array_context(ary: ArrayOrContainerT, + actx: Optional[ArrayContext]) -> ArrayOrContainerT: + """ + Recursively associates *actx* to all the components of *ary*. + + Array container types may use :func:`functools.singledispatch` ``.register`` + to register container-specific implementations. See `this issue + `__ for discussion of + the future of this functionality. + """ try: iterable = serialize_container(ary) except NotAnArrayContainerError: - return actx.thaw(ary) + return ary else: - return deserialize_container(ary, [ - (key, thaw(subary, actx)) for key, subary in iterable - ]) + return deserialize_container(ary, [(key, with_array_context(subary, actx)) + for key, subary in iterable]) # }}} diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index 72c9318c9be7f5b7d298a1fb4e47015fc87791a5..154300a665f461c973a709dabf1ad874474c51de 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -32,6 +32,8 @@ import numpy as np from typing import Union, Callable, Any from pytools.tag import ToTagSetConvertible from arraycontext.context import ArrayContext, _ScalarLike +from arraycontext.container.traversal import (with_array_context, + rec_map_array_container) class EagerJAXArrayContext(ArrayContext): @@ -84,10 +86,21 @@ class EagerJAXArrayContext(ArrayContext): " operations using ArrayContext.np.") def freeze(self, array): - return array.block_until_ready() + from jax.numpy import DeviceArray + + def _rec_freeze(ary): + if isinstance(ary, DeviceArray): + pass + else: + raise TypeError(f"{type(self).__name__}.thaw expects " + f"`jax.DeviceArray` got {type(ary)}.") + return ary.block_until_ready() + + return with_array_context(rec_map_array_container(_rec_freeze, array), + actx=None) def thaw(self, array): - return array + return with_array_context(array, actx=self) # }}} diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index ab8f568526f8202e9b721a6ba91d7da282b72f6f..de0b27b18ab6d6e8deacafab326bb264d11eeb50 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -36,7 +36,8 @@ import numpy as np from pytools.tag import ToTagSetConvertible from arraycontext.context import ArrayContext, _ScalarLike -from arraycontext.container.traversal import rec_map_array_container +from arraycontext.container.traversal import (rec_map_array_container, + with_array_context) if TYPE_CHECKING: @@ -207,27 +208,41 @@ class PyOpenCLArrayContext(ArrayContext): for name, ary in result.items()} def freeze(self, array): - array.finish() - return array.with_queue(None) + import pyopencl.array as cl_array + + def _rec_freeze(ary): + if isinstance(ary, cl_array.Array): + ary.finish() + return ary.with_queue(None) + else: + raise TypeError(f"{type(self).__name__} cannot freeze" + f" arrays of type '{type(ary).__name__}'.") + + return with_array_context(rec_map_array_container(_rec_freeze, array), + actx=None) def thaw(self, array): from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray, to_tagged_cl_array) import pyopencl.array as cl_array - if isinstance(array, TaggableCLArray): - return array.with_queue(self.queue) - elif isinstance(array, cl_array.Array): - from warnings import warn - warn("Invoking PyOpenCLArrayContext.thaw with pyopencl.Array" - " will be unsupported in 2023. Use `to_tagged_cl_array`" - " to convert instances of pyopencl.Array to TaggableCLArray.", - DeprecationWarning, stacklevel=2) - return (to_tagged_cl_array(array, axes=None, tags=frozenset()) - .with_queue(self.queue)) - else: - raise ValueError("array should be a cl.array.Array," - f" got '{type(array)}'") + def _rec_thaw(ary): + if isinstance(ary, TaggableCLArray): + return ary.with_queue(self.queue) + elif isinstance(ary, cl_array.Array): + from warnings import warn + warn("Invoking PyOpenCLArrayContext.thaw with pyopencl.Array" + " will be unsupported in 2023. Use `to_tagged_cl_array`" + " to convert instances of pyopencl.Array to TaggableCLArray.", + DeprecationWarning, stacklevel=2) + return (to_tagged_cl_array(ary, axes=None, tags=frozenset()) + .with_queue(self.queue)) + else: + raise ValueError("array should be a cl.array.Array," + f" got '{type(ary)}'") + + return with_array_context(rec_map_array_container(_rec_thaw, array), + actx=self) # }}} diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 211f0c4ba932901d924a7ca599b010c4c1f020df..e3813b274cb39887d7bf8a80d11530eef6af9045 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -43,11 +43,12 @@ THE SOFTWARE. """ from arraycontext.context import ArrayContext, _ScalarLike -from arraycontext.container.traversal import rec_map_array_container +from arraycontext.container.traversal import (rec_map_array_container, + with_array_context) from arraycontext.metadata import NameHint import numpy as np -from typing import Any, Callable, Union, TYPE_CHECKING, Tuple, Type, FrozenSet +from typing import Any, Callable, Union, TYPE_CHECKING, Tuple, Type, FrozenSet, Dict from pytools.tag import ToTagSetConvertible, normalize_tags, Tag import abc @@ -248,45 +249,69 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): import pytato as pt import pyopencl.array as cla import loopy as lp + + from arraycontext.container import ArrayT + from arraycontext.container.traversal import rec_keyed_map_array_container from arraycontext.impl.pytato.utils import (_normalize_pt_expr, get_cl_axes_from_pt_axes) from arraycontext.impl.pyopencl.taggable_cl_array import (to_tagged_cl_array, TaggableCLArray) + from arraycontext.impl.pytato.compile import _ary_container_key_stringifier - if isinstance(array, TaggableCLArray): - return array.with_queue(None) - if isinstance(array, cla.Array): - from warnings import warn - warn("Freezing pyopencl.array.Array will be deprecated in 2023." - " Use `to_tagged_cl_array` to convert the array to" - " TaggableCLArray", DeprecationWarning, stacklevel=2) - return to_tagged_cl_array(array.with_queue(None), - axes=None, - tags=frozenset()) - if isinstance(array, pt.DataWrapper): - # trivial freeze. - return to_tagged_cl_array(array.data.with_queue(None), - axes=get_cl_axes_from_pt_axes(array.axes), - tags=array.tags) - if not isinstance(array, pt.Array): - raise TypeError(f"{type(self).__name__}.freeze invoked " - f"with non-pytato array of type '{type(array)}'") - - # {{{ early exit for 0-sized arrays - - if array.size == 0: - return to_tagged_cl_array( - cla.empty(self.queue.context, - shape=array.shape, - dtype=array.dtype, - allocator=self.allocator), - get_cl_axes_from_pt_axes(array.axes), - array.tags) + array_as_dict: Dict[str, Union[cla.Array, TaggableCLArray, + pt.Array]] = {} + key_to_frozen_subary: Dict[str, TaggableCLArray] = {} + key_to_pt_arrays: Dict[str, pt.Array] = {} + + def _record_leaf_ary_in_dict(key: Tuple[Any, ...], + ary: ArrayT): + key_str = "_actx" + _ary_container_key_stringifier(key) + array_as_dict[key_str] = ary + return ary + + rec_keyed_map_array_container(_record_leaf_ary_in_dict, array) + + # {{{ remove any non pytato arrays from array_as_dict + + for key, subary in array_as_dict.items(): + if isinstance(subary, TaggableCLArray): + key_to_frozen_subary[key] = subary.with_queue(None) + elif isinstance(subary, cla.Array): + from warnings import warn + warn("Freezing pyopencl.array.Array will be deprecated in 2023." + " Use `to_tagged_cl_array` to convert the array to" + " TaggableCLArray", DeprecationWarning, stacklevel=2) + key_to_frozen_subary[key] = to_tagged_cl_array( + subary.with_queue(None), + axes=None, + tags=frozenset()) + elif isinstance(subary, pt.DataWrapper): + # trivial freeze. + key_to_frozen_subary[key] = to_tagged_cl_array( + subary.data, + axes=get_cl_axes_from_pt_axes(subary.axes), + tags=subary.tags) + else: + if not isinstance(subary, pt.Array): + raise TypeError(f"{type(self).__name__}.freeze invoked " + f"with non-pytato array of type '{type(array)}'") + + if subary.size == 0: + # early exit for 0-sized arrays + key_to_frozen_subary[key] = to_tagged_cl_array( + cla.empty(self.queue.context, + shape=subary.shape, + dtype=subary.dtype, + allocator=self.allocator), + get_cl_axes_from_pt_axes(subary.axes), + subary.tags) + else: + key_to_pt_arrays[key] = subary # }}} pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( - {"_actx_out": array}) + key_to_pt_arrays) normalized_expr, bound_arguments = _normalize_pt_expr( pt_dict_of_named_arrays) @@ -306,16 +331,31 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): cl_device=self.queue.device) pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program) self._freeze_prg_cache[normalized_expr] = pt_prg + else: + transformed_dag = self._dag_transform_cache[normalized_expr] assert len(pt_prg.bound_arguments) == 0 evt, out_dict = pt_prg(self.queue, **bound_arguments) evt.wait() - - return to_tagged_cl_array( - out_dict["_actx_out"].with_queue(None), - get_cl_axes_from_pt_axes( - self._dag_transform_cache[normalized_expr]["_actx_out"].expr.axes), - array.tags) + assert len(set(out_dict) & set(key_to_frozen_subary)) == 0 + + key_to_frozen_subary = { + **key_to_frozen_subary, + **{k: to_tagged_cl_array(v.with_queue(None), + get_cl_axes_from_pt_axes(transformed_dag[k] + .expr + .axes), + transformed_dag[k].expr.tags) + for k, v in out_dict.items()} + } + + def _to_frozen(key: Tuple[Any, ...], ary: ArrayT): + key_str = "_actx" + _ary_container_key_stringifier(key) + return key_to_frozen_subary[key_str] + + return with_array_context(rec_keyed_map_array_container(_to_frozen, + array), + actx=None) def thaw(self, array): import pytato as pt @@ -324,18 +364,21 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): to_tagged_cl_array) import pyopencl.array as cl_array - if isinstance(array, TaggableCLArray): - pass - elif isinstance(array, cl_array.Array): - array = to_tagged_cl_array(array, axes=None, tags=frozenset()) - else: - raise TypeError(f"{type(self).__name__}.thaw expects " - "'TaggableCLArray' or 'cl.array.Array' got " - f"{type(array)}.") + def _rec_thaw(ary): + if isinstance(ary, TaggableCLArray): + pass + elif isinstance(ary, cl_array.Array): + ary = to_tagged_cl_array(ary, axes=None, tags=frozenset()) + else: + raise TypeError(f"{type(self).__name__}.thaw expects " + "'TaggableCLArray' or 'cl.array.Array' got " + f"{type(ary)}.") + return pt.make_data_wrapper(ary.with_queue(self.queue), + axes=get_pt_axes_from_cl_axes(ary.axes), + tags=ary.tags) - return pt.make_data_wrapper(array.with_queue(self.queue), - axes=get_pt_axes_from_cl_axes(array.axes), - tags=array.tags) + return with_array_context(rec_map_array_container(_rec_thaw, array), + actx=self) # }}} @@ -443,41 +486,75 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): def freeze(self, array): import pytato as pt + from jax.numpy import DeviceArray + from arraycontext.container import ArrayT + from arraycontext.container.traversal import rec_keyed_map_array_container + from arraycontext.impl.pytato.compile import _ary_container_key_stringifier + + array_as_dict: Dict[str, Union[DeviceArray, pt.Array]] = {} + key_to_frozen_subary: Dict[str, DeviceArray] = {} + key_to_pt_arrays: Dict[str, pt.Array] = {} + + def _record_leaf_ary_in_dict(key: Tuple[Any, ...], + ary: Union[DeviceArray, pt.Array]): + key_str = "_actx" + _ary_container_key_stringifier(key) + array_as_dict[key_str] = ary + return ary - if isinstance(array, DeviceArray): - return array.block_until_ready() - if not isinstance(array, pt.Array): - raise TypeError(f"{type(self)}.freeze invoked with " - f"non-pytato array of type '{type(array)}'") + rec_keyed_map_array_container(_record_leaf_ary_in_dict, array) - from arraycontext.impl.pytato.utils import _normalize_pt_expr - pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( - {"_actx_out": array}) + # {{{ remove any non pytato arrays from array_as_dict - normalized_expr, bound_arguments = _normalize_pt_expr( - pt_dict_of_named_arrays) + for key, subary in array_as_dict.items(): + if isinstance(subary, DeviceArray): + key_to_frozen_subary[key] = subary.block_until_ready() + elif isinstance(subary, pt.DataWrapper): + # trivial freeze. + key_to_frozen_subary[key] = subary.data.block_until_ready() + else: + if not isinstance(subary, pt.Array): + raise TypeError(f"{type(self).__name__}.freeze invoked " + f"with non-pytato array of type '{type(array)}'") - try: - pt_prg = self._freeze_prg_cache[normalized_expr] - except KeyError: - pt_prg = pt.generate_jax(self.transform_dag(normalized_expr), - jit=True) - self._freeze_prg_cache[normalized_expr] = pt_prg + key_to_pt_arrays[key] = subary - assert len(pt_prg.bound_arguments) == 0 - out_dict = pt_prg(**bound_arguments) + # }}} + + pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(key_to_pt_arrays) + transformed_dag = self.transform_dag(pt_dict_of_named_arrays) + pt_prg = pt.generate_jax(transformed_dag, jit=True) + out_dict = pt_prg() + assert len(set(out_dict) & set(key_to_frozen_subary)) == 0 + + key_to_frozen_subary = { + **key_to_frozen_subary, + **{k: v.block_until_ready() + for k, v in out_dict.items()} + } + + def _to_frozen(key: Tuple[Any, ...], ary: ArrayT): + key_str = "_actx" + _ary_container_key_stringifier(key) + return key_to_frozen_subary[key_str] - return out_dict["_actx_out"].block_until_ready() + return with_array_context(rec_keyed_map_array_container(_to_frozen, + array), + actx=None) def thaw(self, array): import pytato as pt + from jax.numpy import DeviceArray - if not isinstance(array, self.frozen_array_types): - raise TypeError(f"{type(self)}.thaw expects jax device arrays, got " - f"{type(array)}") + def _rec_thaw(ary): + if isinstance(ary, DeviceArray): + pass + else: + raise TypeError(f"{type(self).__name__}.thaw expects " + f"'jax.DeviceArray' got {type(ary)}.") + return pt.make_data_wrapper(ary) - return pt.make_data_wrapper(array) + return with_array_context(rec_map_array_container(_rec_thaw, array), + actx=self) def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: from .compile import LazilyJAXCompilingFunctionCaller