diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py index ecde87d14cf19c4cce6761a7623991318879d971..d34344788695e783ff5f5baa6f3dd2388e8f2115 100644 --- a/arraycontext/impl/pyopencl/taggable_cl_array.py +++ b/arraycontext/impl/pyopencl/taggable_cl_array.py @@ -6,8 +6,8 @@ """ import pyopencl.array as cla -from typing import FrozenSet, Union, Sequence, Optional, Tuple -from pytools.tag import Taggable, Tag +from typing import Any, Dict, FrozenSet, Optional, Tuple +from pytools.tag import Taggable, Tag, TagsType, TagOrIterableType from dataclasses import dataclass from pytools import memoize @@ -29,6 +29,20 @@ def _construct_untagged_axes(ndim: int) -> Tuple[Axis, ...]: return tuple(Axis(frozenset()) for _ in range(ndim)) +def _unwrap_cl_array(ary: cla.Array) -> Dict[str, Any]: + return dict(shape=ary.shape, dtype=ary.dtype, + allocator=ary.allocator, + strides=ary.strides, + data=ary.base_data, + offset=ary.offset, + events=ary.events, + _context=ary.context, + _queue=ary.queue, + _size=ary.size, + _fast=True, + ) + + class TaggableCLArray(cla.Array, Taggable): """ A :class:`pyopencl.array.Array` with additional metadata. This is used by @@ -60,58 +74,40 @@ class TaggableCLArray(cla.Array, Taggable): _size=_size, _context=_context, _queue=_queue) + if __debug__: + if not isinstance(tags, frozenset): + raise TypeError("tags are not a frozenset") + + if axes is not None and len(axes) != self.ndim: + raise ValueError("axes length does not match array dimension: " + f"got {len(axes)} axes for {self.ndim}d array") + + if axes is None: + axes = _construct_untagged_axes(self.ndim) + self.tags = tags - axes = axes if axes is not None else _construct_untagged_axes(len(self - .shape)) self.axes = axes - def copy(self, queue=cla._copy_queue, tags=None, axes=None, _new_class=None): - """ - :arg _new_class: The class of the copy. :func:`to_tagged_cl_array` is - sets this to convert instances of :class:`pyopencl.array.Array` to - :class:`TaggableCLArray`. If not provided, defaults to - ``self.__class__``. - """ - _new_class = self.__class__ if _new_class is None else _new_class - - if queue is not cla._copy_queue: - # Copying command queue is an involved operation, use super-class' - # implementation. - base_instance = super().copy(queue=queue) - else: - base_instance = self - - if tags is None and axes is None and _new_class is self.__class__: - # early exit - return base_instance - - tags = getattr(base_instance, "tags", frozenset()) if tags is None else tags - axes = getattr(base_instance, "axes", None) if axes is None else axes - - return _new_class(None, - base_instance.shape, - base_instance.dtype, - allocator=base_instance.allocator, - strides=base_instance.strides, - data=base_instance.base_data, - offset=base_instance.offset, - events=base_instance.events, _fast=True, - _context=base_instance.context, - _queue=base_instance.queue, - _size=base_instance.size, - tags=tags, - axes=axes, - ) + def copy(self, queue=cla._copy_queue): + ary = super().copy(queue=queue) + return type(self)(None, tags=self.tags, axes=self.axes, + **_unwrap_cl_array(ary)) + + def _with_new_tags(self, tags: TagsType) -> "TaggableCLArray": + return type(self)(None, tags=tags, axes=self.axes, + **_unwrap_cl_array(self)) def with_tagged_axis(self, iaxis: int, - tags: Union[Sequence[Tag], Tag]) -> "TaggableCLArray": + tags: TagOrIterableType) -> "TaggableCLArray": """ Returns a copy of *self* with *iaxis*-th axis tagged with *tags*. """ new_axes = (self.axes[:iaxis] + (self.axes[iaxis].tagged(tags),) + self.axes[iaxis+1:]) - return self.copy(axes=new_axes) + + return type(self)(None, tags=self.tags, axes=new_axes, + **_unwrap_cl_array(self)) def to_tagged_cl_array(ary: cla.Array, @@ -119,11 +115,32 @@ def to_tagged_cl_array(ary: cla.Array, tags: FrozenSet[Tag]) -> TaggableCLArray: """ Returns a :class:`TaggableCLArray` that is constructed from the data in - *ary* along with the metadata from *axes* and *tags*. + *ary* along with the metadata from *axes* and *tags*. If *ary* is already a + :class:`TaggableCLArray`, the new *tags* and *axes* are added to the + existing ones. :arg axes: An instance of :class:`Axis` for each dimension of the array. If passed *None*, then initialized to a :class:`pytato.Axis` with no tags attached for each dimension. """ - return TaggableCLArray.copy(ary, axes=axes, tags=tags, - _new_class=TaggableCLArray) + if axes and len(axes) != ary.ndim: + raise ValueError("axes length does not match array dimension: " + f"got {len(axes)} axes for {ary.ndim}d array") + + from pytools.tag import normalize_tags + tags = normalize_tags(tags) + + if isinstance(ary, TaggableCLArray): + if axes: + for i, axis in enumerate(axes): + ary = ary.with_tagged_axis(i, axis.tags) + + if tags: + ary = ary.tagged(tags) + + return ary + elif isinstance(ary, cla.Array): + return TaggableCLArray(None, tags=tags, axes=axes, + **_unwrap_cl_array(ary)) + else: + raise TypeError(f"unsupported array type: '{type(ary).__name__}'") diff --git a/setup.py b/setup.py index 62ff4a7bae995e792efb219c3febb77e9f6f0066..8b0d677b49f2fc633a4ce3b1da8d8f8a146df57e 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,10 @@ def main(): python_requires="~=3.6", install_requires=[ "numpy", - "pytools>=2020.4.1", + + # https://github.com/inducer/arraycontext/pull/147 + "pytools>=2022.1.1", + "pytest>=2.3", "loopy>=2019.1", "dataclasses; python_version<'3.7'", diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 92195f5aabb19796944e86e1e37410c6afc5b100..54acefceafe5a00b989e3cdcc0af8e1a6ae9b7ac 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1415,6 +1415,8 @@ def test_array_container_with_numpy(actx_factory): # }}} +# {{{ test_actx_compile_on_pure_array_return + def test_actx_compile_on_pure_array_return(actx_factory): def _twice(x): return 2 * x @@ -1424,6 +1426,55 @@ def test_actx_compile_on_pure_array_return(actx_factory): np.testing.assert_allclose(actx.to_numpy(_twice(ones)), actx.to_numpy(actx.compile(_twice)(ones))) +# }}} + + +# {{{ + +def test_taggable_cl_array_tags(actx_factory): + actx = actx_factory() + if not isinstance(actx, PyOpenCLArrayContext): + pytest.skip(f"not relevant for '{type(actx).__name__}'") + + import pyopencl.array as cl_array + ary = cl_array.to_device(actx.queue, np.zeros((32, 7))) + + # {{{ check tags are set + + from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array + tagged_ary = to_tagged_cl_array(ary, axes=None, + tags=frozenset((FirstAxisIsElementsTag(),))) + + assert tagged_ary.base_data is ary.base_data + assert tagged_ary.tags == frozenset((FirstAxisIsElementsTag(),)) + + # }}} + + # {{{ check tags are appended + + from arraycontext import ElementwiseMapKernelTag + tagged_ary = to_tagged_cl_array(tagged_ary, axes=None, + tags=frozenset((ElementwiseMapKernelTag(),))) + + assert tagged_ary.base_data is ary.base_data + assert tagged_ary.tags == frozenset( + (FirstAxisIsElementsTag(), ElementwiseMapKernelTag()) + ) + + # }}} + + # {{{ test copied tags + + copy_tagged_ary = tagged_ary.copy() + + assert copy_tagged_ary.tags == tagged_ary.tags + assert copy_tagged_ary.axes == tagged_ary.axes + assert copy_tagged_ary.base_data != tagged_ary.base_data + + # }}} + +# }}} + if __name__ == "__main__": import sys