diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index f7592520ccdf98b8fa6cc8e05c743c9969dea854..730133e25104f8824322e1bd8b17cbbe2bd01fab 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -79,6 +79,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): self.allocator = allocator self.array_types = (pt.Array, ) self._freeze_prg_cache = {} + self._dag_transform_cache = {} # unused, but necessary to keep the context alive self.context = self.queue.context @@ -113,24 +114,56 @@ class PytatoPyOpenCLArrayContext(ArrayContext): return cl_array.get(queue=self.queue) def call_loopy(self, program, **kwargs): - import pyopencl.array as cla + from pytato.scalar_expr import SCALAR_CLASSES from pytato.loopy import call_loopy + from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray entrypoint = program.default_entrypoint.name - # thaw frozen arrays - kwargs = {kw: (self.thaw(arg) if isinstance(arg, cla.Array) else arg) - for kw, arg in kwargs.items()} + # {{{ preprocess args + + processed_kwargs = {} + + for kw, arg in sorted(kwargs.items()): + if isinstance(arg, self.array_types + SCALAR_CLASSES): + pass + elif isinstance(arg, TaggableCLArray): + arg = self.thaw(arg) + else: + raise ValueError(f"call_loopy argument '{kw}' expected to be an" + " instance of 'pytato.Array', 'Number' or" + f"'TaggableCLArray', got '{type(arg)}'") + + processed_kwargs[kw] = arg + + # }}} - return call_loopy(program, kwargs, entrypoint) + return call_loopy(program, processed_kwargs, entrypoint) def freeze(self, array): import pytato as pt import pyopencl.array as cla import loopy as lp + from arraycontext.impl.pytato.utils import (_normalize_pt_expr, + get_cl_axes_from_pt_axes) + from arraycontext.impl.pyopencl.taggable_cl_array import (to_tagged_cl_array, + TaggableCLArray) - if isinstance(array, cla.Array): + if isinstance(array, TaggableCLArray): return array.with_queue(None) + if isinstance(array, cla.Array): + from warnings import warn + warn("Freezing pyopencl.array.Array will be deprecated in 2023." + " Use `to_tagged_cl_array` to convert the array to" + " TaggableCLArray", DeprecationWarning, stacklevel=2) + return to_tagged_cl_array(array.with_queue(None), + axes=None, + tags=frozenset()) + if isinstance(array, pt.DataWrapper): + # trivial freeze. + return to_tagged_cl_array(array.data.with_queue(None), + axes=get_cl_axes_from_pt_axes(array.axes), + tags=array.tags) if not isinstance(array, pt.Array): raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with " f"non-pytato array of type '{type(array)}'") @@ -138,14 +171,16 @@ class PytatoPyOpenCLArrayContext(ArrayContext): # {{{ early exit for 0-sized arrays if array.size == 0: - return cla.empty(self.queue.context, - shape=array.shape, - dtype=array.dtype, - allocator=self.allocator) + return to_tagged_cl_array( + cla.empty(self.queue.context, + shape=array.shape, + dtype=array.dtype, + allocator=self.allocator), + get_cl_axes_from_pt_axes(array.axes), + array.tags) # }}} - from arraycontext.impl.pytato.utils import _normalize_pt_expr pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( {"_actx_out": array}) @@ -155,7 +190,13 @@ class PytatoPyOpenCLArrayContext(ArrayContext): try: pt_prg = self._freeze_prg_cache[normalized_expr] except KeyError: - pt_prg = pt.generate_loopy(self.transform_dag(normalized_expr), + if normalized_expr in self._dag_transform_cache: + transformed_dag = self._dag_transform_cache[normalized_expr] + else: + transformed_dag = self.transform_dag(normalized_expr) + self._dag_transform_cache[normalized_expr] = transformed_dag + + pt_prg = pt.generate_loopy(transformed_dag, options=lp.Options(return_dict=True, no_numpy=True), cl_device=self.queue.device) @@ -166,17 +207,31 @@ class PytatoPyOpenCLArrayContext(ArrayContext): evt, out_dict = pt_prg(self.queue, **bound_arguments) evt.wait() - return out_dict["_actx_out"].with_queue(None) + return to_tagged_cl_array( + out_dict["_actx_out"].with_queue(None), + get_cl_axes_from_pt_axes( + self._dag_transform_cache[normalized_expr]["_actx_out"].expr.axes), + array.tags) def thaw(self, array): import pytato as pt - import pyopencl.array as cla - - if not isinstance(array, cla.Array): - raise TypeError("PytatoPyOpenCLArrayContext.thaw expects CL arrays, got " - f"{type(array)}") - - return pt.make_data_wrapper(array.with_queue(self.queue)) + from .utils import get_pt_axes_from_cl_axes + from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray, + to_tagged_cl_array) + import pyopencl.array as cl_array + + if isinstance(array, TaggableCLArray): + pass + elif isinstance(array, cl_array.Array): + array = to_tagged_cl_array(array, axes=None, tags=frozenset()) + else: + raise TypeError("PytatoPyOpenCLArrayContext.thaw expects " + "'TaggableCLArray' or 'cl.array.Array' got " + f"{type(array)}.") + + return pt.make_data_wrapper(array.with_queue(self.queue), + axes=get_pt_axes_from_cl_axes(array.axes), + tags=array.tags) # }}} @@ -219,12 +274,23 @@ class PytatoPyOpenCLArrayContext(ArrayContext): def einsum(self, spec, *args, arg_names=None, tagged=()): import pyopencl.array as cla import pytato as pt + from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray, + to_tagged_cl_array) if arg_names is None: arg_names = (None,) * len(args) def preprocess_arg(name, arg): - if isinstance(arg, cla.Array): + if isinstance(arg, TaggableCLArray): ary = self.thaw(arg) + elif isinstance(arg, cla.Array): + from warnings import warn + warn("Passing pyopencl.array.Array to einsum will be " + "deprecated in 2023." + " Use `to_tagged_cl_array` to convert the array to" + " TaggableCLArray.", DeprecationWarning, stacklevel=2) + ary = self.thaw(to_tagged_cl_array(arg, + axes=None, + tags=frozenset())) else: assert isinstance(arg, pt.Array) ary = arg diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 71f98a8c5c6ffd5dbb3027082b8cd87d0b6fe146..d83e376a5c10f7f8ea82a7f4547978308af157f6 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -34,7 +34,7 @@ from arraycontext.container.traversal import rec_keyed_map_array_container import abc import numpy as np -from typing import Any, Callable, Tuple, Dict, Mapping +from typing import Any, Callable, Tuple, Dict, Mapping, FrozenSet from dataclasses import dataclass, field from pyrsistent import pmap, PMap @@ -169,7 +169,11 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name): elif is_array_container_type(arg.__class__): def _rec_to_placeholder(keys, ary): name = arg_id_to_name[(kw,) + keys] - return pt.make_placeholder(name, ary.shape, ary.dtype) + return pt.make_placeholder(name, + ary.shape, + ary.dtype, + axes=ary.axes, + tags=ary.tags) return rec_keyed_map_array_container(_rec_to_placeholder, arg) else: @@ -204,6 +208,13 @@ class LazilyCompilingFunctionCaller: with ProcessLogger(logger, "transform_dag"): pt_dict_of_named_arrays = self.actx.transform_dag(dict_of_named_arrays) + name_in_program_to_tags = { + name: out.tags + for name, out in pt_dict_of_named_arrays._data.items()} + name_in_program_to_axes = { + name: out.axes + for name, out in pt_dict_of_named_arrays._data.items()} + with ProcessLogger(logger, "generate_loopy"): pytato_program = pt.generate_loopy(pt_dict_of_named_arrays, options=lp.Options( @@ -225,7 +236,7 @@ class LazilyCompilingFunctionCaller: .actx .transform_loopy_program)) - return pytato_program + return pytato_program, name_in_program_to_tags, name_in_program_to_axes def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays, input_id_to_name_in_program, output_id_to_name_in_program, @@ -234,18 +245,23 @@ class LazilyCompilingFunctionCaller: output_id = "_pt_out" dict_of_named_arrays = pt.make_dict_of_named_arrays( {output_id: ary_or_dict_of_named_arrays}) - pytato_program = self._dag_to_transformed_loopy_prg(dict_of_named_arrays) + pytato_program, name_in_program_to_tags, name_in_program_to_axes = ( + self._dag_to_transformed_loopy_prg(dict_of_named_arrays)) return CompiledFunctionReturningArray( self.actx, pytato_program, input_id_to_name_in_program=input_id_to_name_in_program, - output_name_in_program=output_id) + output_tags=name_in_program_to_tags[output_id], + output_axes=name_in_program_to_axes[output_id], + output_name=output_id) elif isinstance(ary_or_dict_of_named_arrays, pt.DictOfNamedArrays): - pytato_program = self._dag_to_transformed_loopy_prg( - ary_or_dict_of_named_arrays) + pytato_program, name_in_program_to_tags, name_in_program_to_axes = ( + self._dag_to_transformed_loopy_prg(ary_or_dict_of_named_arrays)) return CompiledFunctionReturningArrayContainer( self.actx, pytato_program, input_id_to_name_in_program=input_id_to_name_in_program, output_id_to_name_in_program=output_id_to_name_in_program, + name_in_program_to_tags=name_in_program_to_tags, + name_in_program_to_axes=name_in_program_to_axes, output_template=output_template) else: raise NotImplementedError(type(ary_or_dict_of_named_arrays)) @@ -312,6 +328,8 @@ class LazilyCompilingFunctionCaller: def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): + from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray + input_kwargs_for_loopy = {} for arg_id, arg in arg_id_to_arg.items(): @@ -320,7 +338,7 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): elif isinstance(arg, pt.array.DataWrapper): # got a Datwwrapper => simply gets its data arg = arg.data - elif isinstance(arg, cla.Array): + elif isinstance(arg, TaggableCLArray): # got a frozen array => do nothing pass elif isinstance(arg, pt.Array): @@ -383,9 +401,14 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction): pytato_program: pt.target.BoundProgram input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] output_id_to_name_in_program: Mapping[Tuple[Any, ...], str] + name_in_program_to_tags: Mapping[str, FrozenSet[Tag]] + name_in_program_to_axes: Mapping[str, Tuple[pt.Axis, ...]] output_template: ArrayContainer def __call__(self, arg_id_to_arg) -> ArrayContainer: + from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array + from .utils import get_cl_axes_from_pt_axes + input_kwargs_for_loopy = _args_to_cl_buffers( self.actx, self.input_id_to_name_in_program, arg_id_to_arg) @@ -399,7 +422,12 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction): evt.wait() def to_output_template(keys, _): - return self.actx.thaw(out_dict[self.output_id_to_name_in_program[keys]]) + name_in_program = self.output_id_to_name_in_program[keys] + return self.actx.thaw(to_tagged_cl_array( + out_dict[name_in_program], + axes=get_cl_axes_from_pt_axes( + self.name_in_program_to_axes[name_in_program]), + tags=self.name_in_program_to_tags[name_in_program])) return rec_keyed_map_array_container(to_output_template, self.output_template) @@ -415,9 +443,14 @@ class CompiledFunctionReturningArray(CompiledFunction): actx: PytatoPyOpenCLArrayContext pytato_program: pt.target.BoundProgram input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] + output_tags: FrozenSet[Tag] + output_axes: Tuple[pt.Axis, ...] output_name: str def __call__(self, arg_id_to_arg) -> ArrayContainer: + from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array + from .utils import get_cl_axes_from_pt_axes + input_kwargs_for_loopy = _args_to_cl_buffers( self.actx, self.input_id_to_name_in_program, arg_id_to_arg) @@ -430,4 +463,7 @@ class CompiledFunctionReturningArray(CompiledFunction): # running out of memory. This mitigates that risk a bit, for now. evt.wait() - return self.actx.thaw(out_dict[self.output_name]) + return self.actx.thaw(to_tagged_cl_array(out_dict[self.output_name], + axes=get_cl_axes_from_pt_axes( + self.output_axes), + tags=self.output_tags)) diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index f14d166e48bc5e3f1194a74824e39b8d37da0c4e..2babd559856c2d3518301562fd48035297ae4641 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -58,6 +58,7 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper): shape=tuple(self.rec(s) if isinstance(s, Array) else s for s in expr.shape), dtype=expr.dtype, + axes=expr.axes, tags=expr.tags) def map_size_param(self, expr: SizeParam) -> Array: