diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 7336aa9b83cfc3eb524c8da682c6f898cd06abe2..4e6cc384d726aee358e9ea08c3f16350805bdf3e 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -1,6 +1,7 @@ """ .. currentmodule:: arraycontext .. autoclass:: PyOpenCLArrayContext +.. automodule:: arraycontext.impl.pyopencl.taggable_cl_array """ __copyright__ = """ @@ -35,6 +36,7 @@ import numpy as np from pytools.tag import Tag from arraycontext.context import ArrayContext, _ScalarLike +from arraycontext.container.traversal import rec_map_array_container if TYPE_CHECKING: @@ -65,6 +67,8 @@ class PyOpenCLArrayContext(ArrayContext): of arrays are created (e.g. as results of computation), the associated cost may become significant. Using e.g. :class:`pyopencl.tools.MemoryPool` as the allocator can help avoid this cost. + + .. automethod:: transform_loopy_program """ def __init__(self, @@ -109,7 +113,7 @@ class PyOpenCLArrayContext(ArrayContext): DeprecationWarning, stacklevel=2) import pyopencl as cl - import pyopencl.array as cla + import pyopencl.array as cl_array super().__init__() self.context = queue.context @@ -138,7 +142,9 @@ class PyOpenCLArrayContext(ArrayContext): self._loopy_transform_cache: \ Dict["lp.TranslationUnit", "lp.TranslationUnit"] = {} - self.array_types = (cla.Array,) + # TODO: Ideally this should only be `(TaggableCLArray,)`, but + # that would break the logic in the downstream users. + self.array_types = (cl_array.Array,) def _get_fake_numpy_namespace(self): from arraycontext.impl.pyopencl.fake_numpy import PyOpenCLFakeNumpyNamespace @@ -147,18 +153,27 @@ class PyOpenCLArrayContext(ArrayContext): # {{{ ArrayContext interface def empty(self, shape, dtype): - import pyopencl.array as cl_array - return cl_array.empty(self.queue, shape=shape, dtype=dtype, - allocator=self.allocator) + from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray + + return TaggableCLArray(self.queue, shape=shape, dtype=dtype, + allocator=self.allocator) def zeros(self, shape, dtype): import pyopencl.array as cl_array - return cl_array.zeros(self.queue, shape=shape, dtype=dtype, - allocator=self.allocator) + from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array + return to_tagged_cl_array(cl_array.zeros(self.queue, shape=shape, + dtype=dtype, + allocator=self.allocator), + axes=None, tags=frozenset()) def from_numpy(self, array: Union[np.ndarray, _ScalarLike]): import pyopencl.array as cl_array - return cl_array.to_device(self.queue, array, allocator=self.allocator) + from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array + return to_tagged_cl_array(cl_array + .to_device(self.queue, + array, + allocator=self.allocator), + axes=None, tags=frozenset()) def to_numpy(self, array): if np.isscalar(array): @@ -186,14 +201,33 @@ class PyOpenCLArrayContext(ArrayContext): if len(wait_event_queue) > self._wait_event_queue_length: wait_event_queue.pop(0).wait() - return result + from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array + # FIXME: Inherit loopy tags for these arrays + return {name: to_tagged_cl_array(ary, axes=None, tags=frozenset()) + for name, ary in result.items()} def freeze(self, array): array.finish() return array.with_queue(None) def thaw(self, array): - return array.with_queue(self.queue) + from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray, + to_tagged_cl_array) + import pyopencl.array as cl_array + + if isinstance(array, TaggableCLArray): + return array.with_queue(self.queue) + elif isinstance(array, cl_array.Array): + from warnings import warn + warn("Invoking PyOpenCLArrayContext.thaw with pyopencl.Array" + " will be unsupported in 2023. Use `to_tagged_cl_array`" + " to convert instances of pyopencl.Array to TaggableCLArray.", + DeprecationWarning, stacklevel=2) + return (to_tagged_cl_array(array, axes=None, tags=frozenset()) + .with_queue(self.queue)) + else: + raise ValueError("array should be a cl.array.Array," + f" got '{type(array)}'") # }}} @@ -268,12 +302,37 @@ class PyOpenCLArrayContext(ArrayContext): return t_unit def tag(self, tags: Union[Sequence[Tag], Tag], array): - # Sorry, not capable. - return array + import pyopencl.array as cl_array + from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray, + to_tagged_cl_array) + + def _rec_tagged(ary): + if isinstance(ary, TaggableCLArray): + return ary.tagged(tags) + elif isinstance(ary, cl_array.Array): + return to_tagged_cl_array(ary, axes=None, tags=tags) + else: + raise ValueError("array should be a cl.array.Array," + f" got '{type(ary)}'") + + return rec_map_array_container(_rec_tagged, array) def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array): - # Sorry, not capable. - return array + import pyopencl.array as cl_array + from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray, + to_tagged_cl_array) + + def _rec_tagged(ary): + if isinstance(ary, TaggableCLArray): + return ary.with_tagged_axis(iaxis, tags) + elif isinstance(ary, cl_array.Array): + return (to_tagged_cl_array(ary, axes=None, tags=tags) + .with_tagged_axis(iaxis, tags)) + else: + raise ValueError("array should be a cl.array.Array," + f" got '{type(ary)}'") + + return rec_map_array_container(_rec_tagged, array) def clone(self): return type(self)(self.queue, self.allocator, diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py new file mode 100644 index 0000000000000000000000000000000000000000..ecde87d14cf19c4cce6761a7623991318879d971 --- /dev/null +++ b/arraycontext/impl/pyopencl/taggable_cl_array.py @@ -0,0 +1,129 @@ +""" +.. autoclass:: TaggableCLArray +.. autoclass:: Axis + +.. autofunction:: to_tagged_cl_array +""" + +import pyopencl.array as cla +from typing import FrozenSet, Union, Sequence, Optional, Tuple +from pytools.tag import Taggable, Tag +from dataclasses import dataclass +from pytools import memoize + + +@dataclass(frozen=True, eq=True) +class Axis(Taggable): + """ + Records the tags corresponding to a dimensions of :class:`TaggableCLArray`. + """ + tags: FrozenSet[Tag] + + def copy(self, **kwargs): + from dataclasses import replace + return replace(self, **kwargs) + + +@memoize +def _construct_untagged_axes(ndim: int) -> Tuple[Axis, ...]: + return tuple(Axis(frozenset()) for _ in range(ndim)) + + +class TaggableCLArray(cla.Array, Taggable): + """ + A :class:`pyopencl.array.Array` with additional metadata. This is used by + :class:`~arraycontext.PytatoPyOpenCLArrayContext` to preserve tags for data + while frozen, and also in a similar capacity by + :class:`~arraycontext.PyOpenCLArrayContext`. + + .. attribute:: axes + + A :class:`tuple` of instances of :class:`Axis`, with one :class:`Axis` + for each dimension of the array. + + .. attribute:: tags + + A :class:`frozenset` of :class:`pytools.tag.Tag`. Typically intended to + record application-specific metadata to drive the optimizations in + :meth:`arraycontext.PyOpenCLArrayContext.transform_loopy_program`. + """ + def __init__(self, cq, shape, dtype, order="C", allocator=None, + data=None, offset=0, strides=None, events=None, _flags=None, + _fast=False, _size=None, _context=None, _queue=None, + axes=None, tags=frozenset()): + + super().__init__(cq=cq, shape=shape, dtype=dtype, + order=order, allocator=allocator, + data=data, offset=offset, + strides=strides, events=events, + _flags=_flags, _fast=_fast, + _size=_size, _context=_context, + _queue=_queue) + + self.tags = tags + axes = axes if axes is not None else _construct_untagged_axes(len(self + .shape)) + self.axes = axes + + def copy(self, queue=cla._copy_queue, tags=None, axes=None, _new_class=None): + """ + :arg _new_class: The class of the copy. :func:`to_tagged_cl_array` is + sets this to convert instances of :class:`pyopencl.array.Array` to + :class:`TaggableCLArray`. If not provided, defaults to + ``self.__class__``. + """ + _new_class = self.__class__ if _new_class is None else _new_class + + if queue is not cla._copy_queue: + # Copying command queue is an involved operation, use super-class' + # implementation. + base_instance = super().copy(queue=queue) + else: + base_instance = self + + if tags is None and axes is None and _new_class is self.__class__: + # early exit + return base_instance + + tags = getattr(base_instance, "tags", frozenset()) if tags is None else tags + axes = getattr(base_instance, "axes", None) if axes is None else axes + + return _new_class(None, + base_instance.shape, + base_instance.dtype, + allocator=base_instance.allocator, + strides=base_instance.strides, + data=base_instance.base_data, + offset=base_instance.offset, + events=base_instance.events, _fast=True, + _context=base_instance.context, + _queue=base_instance.queue, + _size=base_instance.size, + tags=tags, + axes=axes, + ) + + def with_tagged_axis(self, iaxis: int, + tags: Union[Sequence[Tag], Tag]) -> "TaggableCLArray": + """ + Returns a copy of *self* with *iaxis*-th axis tagged with *tags*. + """ + new_axes = (self.axes[:iaxis] + + (self.axes[iaxis].tagged(tags),) + + self.axes[iaxis+1:]) + return self.copy(axes=new_axes) + + +def to_tagged_cl_array(ary: cla.Array, + axes: Optional[Tuple[Axis, ...]], + tags: FrozenSet[Tag]) -> TaggableCLArray: + """ + Returns a :class:`TaggableCLArray` that is constructed from the data in + *ary* along with the metadata from *axes* and *tags*. + + :arg axes: An instance of :class:`Axis` for each dimension of the + array. If passed *None*, then initialized to a :class:`pytato.Axis` + with no tags attached for each dimension. + """ + return TaggableCLArray.copy(ary, axes=axes, tags=tags, + _new_class=TaggableCLArray) diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index f69a1961e64c9f7a93a30841874b1680299c7032..f14d166e48bc5e3f1194a74824e39b8d37da0c4e 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -24,10 +24,11 @@ THE SOFTWARE. from typing import Any, Dict, Set, Tuple, Mapping -from pytato.array import SizeParam, Placeholder, make_placeholder +from pytato.array import SizeParam, Placeholder, make_placeholder, Axis as PtAxis from pytato.array import Array, DataWrapper, DictOfNamedArrays from pytato.transform import CopyMapper from pytools import UniqueNameGenerator +from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis class _DatawrapperToBoundPlaceholderMapper(CopyMapper): @@ -81,3 +82,11 @@ def _normalize_pt_expr(expr: DictOfNamedArrays) -> Tuple[DictOfNamedArrays, normalize_mapper = _DatawrapperToBoundPlaceholderMapper() normalized_expr = normalize_mapper(expr) return normalized_expr, normalize_mapper.bound_arguments + + +def get_pt_axes_from_cl_axes(axes: Tuple[ClAxis, ...]) -> Tuple[PtAxis, ...]: + return tuple(PtAxis(axis.tags) for axis in axes) + + +def get_cl_axes_from_pt_axes(axes: Tuple[PtAxis, ...]) -> Tuple[ClAxis, ...]: + return tuple(ClAxis(axis.tags) for axis in axes)