From d757876d875afe07c430b9602cb25f8e397c54a0 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Fri, 24 Jun 2022 20:58:46 +0300 Subject: [PATCH] add some helper functions to taggable_cl_array --- arraycontext/impl/pyopencl/__init__.py | 75 +++++++------------ .../impl/pyopencl/taggable_cl_array.py | 65 ++++++++++++++-- 2 files changed, 86 insertions(+), 54 deletions(-) diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 1467246..b1a086d 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -154,27 +154,16 @@ class PyOpenCLArrayContext(ArrayContext): # {{{ ArrayContext interface def empty(self, shape, dtype): - from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray - - return TaggableCLArray(self.queue, shape=shape, dtype=dtype, - allocator=self.allocator) + import arraycontext.impl.pyopencl.taggable_cl_array as tga + return tga.empty(self.queue, shape, dtype, allocator=self.allocator) def zeros(self, shape, dtype): - import pyopencl.array as cl_array - from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array - return to_tagged_cl_array(cl_array.zeros(self.queue, shape=shape, - dtype=dtype, - allocator=self.allocator), - axes=None, tags=frozenset()) + import arraycontext.impl.pyopencl.taggable_cl_array as tga + return tga.zeros(self.queue, shape, dtype, allocator=self.allocator) def from_numpy(self, array: Union[np.ndarray, ScalarLike]): - import pyopencl.array as cl_array - from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array - return to_tagged_cl_array(cl_array - .to_device(self.queue, - array, - allocator=self.allocator), - axes=None, tags=frozenset()) + import arraycontext.impl.pyopencl.taggable_cl_array as tga + return tga.to_device(self.queue, array, allocator=self.allocator) def to_numpy(self, array): if np.isscalar(array): @@ -202,10 +191,9 @@ class PyOpenCLArrayContext(ArrayContext): if len(wait_event_queue) > self._wait_event_queue_length: wait_event_queue.pop(0).wait() - from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array + import arraycontext.impl.pyopencl.taggable_cl_array as tga # FIXME: Inherit loopy tags for these arrays - return {name: to_tagged_cl_array(ary, axes=None, tags=frozenset()) - for name, ary in result.items()} + return {name: tga.to_tagged_cl_array(ary) for name, ary in result.items()} def freeze(self, array): import pyopencl.array as cl_array @@ -222,24 +210,21 @@ class PyOpenCLArrayContext(ArrayContext): actx=None) def thaw(self, array): - from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray, - to_tagged_cl_array) - import pyopencl.array as cl_array + import arraycontext.impl.pyopencl.taggable_cl_array as tga def _rec_thaw(ary): - if isinstance(ary, TaggableCLArray): + if isinstance(ary, tga.TaggableCLArray): return ary.with_queue(self.queue) - elif isinstance(ary, cl_array.Array): + elif isinstance(ary, self.array_types): from warnings import warn - warn("Invoking PyOpenCLArrayContext.thaw with pyopencl.Array" + warn(f"Invoking PyOpenCLArrayContext.thaw with {type(ary).__name__}" " will be unsupported in 2023. Use `to_tagged_cl_array`" - " to convert instances of pyopencl.Array to TaggableCLArray.", + " to convert instances to TaggableCLArray.", DeprecationWarning, stacklevel=2) - return (to_tagged_cl_array(ary, axes=None, tags=frozenset()) - .with_queue(self.queue)) + return (tga.to_tagged_cl_array(ary).with_queue(self.queue)) else: - raise ValueError("array should be a cl.array.Array," - f" got '{type(ary)}'") + raise ValueError( + f"array should be a cl.array.Array, got '{type(ary).__name__}'") return with_array_context(rec_map_array_container(_rec_thaw, array), actx=self) @@ -317,35 +302,31 @@ class PyOpenCLArrayContext(ArrayContext): return t_unit def tag(self, tags: ToTagSetConvertible, array): - import pyopencl.array as cl_array - from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray, - to_tagged_cl_array) + import arraycontext.impl.pyopencl.taggable_cl_array as tga def _rec_tagged(ary): - if isinstance(ary, TaggableCLArray): + if isinstance(ary, tga.TaggableCLArray): return ary.tagged(tags) - elif isinstance(ary, cl_array.Array): - return to_tagged_cl_array(ary, axes=None, tags=tags) + elif isinstance(ary, self.array_types): + return tga.to_tagged_cl_array(ary, tags=tags) else: - raise ValueError("array should be a cl.array.Array," - f" got '{type(ary)}'") + raise ValueError( + f"array should be a cl.array.Array, got '{type(ary).__name__}'") return rec_map_array_container(_rec_tagged, array) def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): - import pyopencl.array as cl_array - from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray, - to_tagged_cl_array) + import arraycontext.impl.pyopencl.taggable_cl_array as tga def _rec_tagged(ary): - if isinstance(ary, TaggableCLArray): + if isinstance(ary, tga.TaggableCLArray): return ary.with_tagged_axis(iaxis, tags) - elif isinstance(ary, cl_array.Array): - return (to_tagged_cl_array(ary, axes=None, tags=tags) + elif isinstance(ary, self.array_types): + return (tga.to_tagged_cl_array(ary, tags=tags) .with_tagged_axis(iaxis, tags)) else: - raise ValueError("array should be a cl.array.Array," - f" got '{type(ary)}'") + raise ValueError( + f"array should be a cl.array.Array, got '{type(ary).__name__}'") return rec_map_array_container(_rec_tagged, array) diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py index 439ca58..49ae08b 100644 --- a/arraycontext/impl/pyopencl/taggable_cl_array.py +++ b/arraycontext/impl/pyopencl/taggable_cl_array.py @@ -5,12 +5,17 @@ .. autofunction:: to_tagged_cl_array """ -import pyopencl.array as cla -from typing import Any, Dict, FrozenSet, Optional, Tuple -from pytools.tag import Taggable, Tag, ToTagSetConvertible from dataclasses import dataclass +from typing import Any, Dict, FrozenSet, Optional, Tuple + +import numpy as np +import pyopencl.array as cla + from pytools import memoize +from pytools.tag import Taggable, Tag, ToTagSetConvertible + +# {{{ utils @dataclass(frozen=True, eq=True) class Axis(Taggable): @@ -42,6 +47,10 @@ def _unwrap_cl_array(ary: cla.Array) -> Dict[str, Any]: _fast=True, ) +# }}} + + +# {{{ TaggableCLArray class TaggableCLArray(cla.Array, Taggable): """ @@ -111,8 +120,8 @@ class TaggableCLArray(cla.Array, Taggable): def to_tagged_cl_array(ary: cla.Array, - axes: Optional[Tuple[Axis, ...]], - tags: FrozenSet[Tag]) -> TaggableCLArray: + axes: Optional[Tuple[Axis, ...]] = None, + tags: FrozenSet[Tag] = frozenset()) -> TaggableCLArray: """ Returns a :class:`TaggableCLArray` that is constructed from the data in *ary* along with the metadata from *axes* and *tags*. If *ary* is already a @@ -123,7 +132,7 @@ def to_tagged_cl_array(ary: cla.Array, array. If passed *None*, then initialized to a :class:`pytato.Axis` with no tags attached for each dimension. """ - if axes and len(axes) != ary.ndim: + if axes is not None and len(axes) != ary.ndim: raise ValueError("axes length does not match array dimension: " f"got {len(axes)} axes for {ary.ndim}d array") @@ -131,7 +140,7 @@ def to_tagged_cl_array(ary: cla.Array, tags = normalize_tags(tags) if isinstance(ary, TaggableCLArray): - if axes: + if axes is not None: for i, axis in enumerate(axes): ary = ary.with_tagged_axis(i, axis.tags) @@ -144,3 +153,45 @@ def to_tagged_cl_array(ary: cla.Array, **_unwrap_cl_array(ary)) else: raise TypeError(f"unsupported array type: '{type(ary).__name__}'") + +# }}} + + +# {{{ creation + +def empty(queue, shape, dtype=float, *, + axes: Optional[Tuple[Axis, ...]] = None, + tags: FrozenSet[Tag] = frozenset(), + order: str = "C", + allocator=None) -> TaggableCLArray: + if dtype is not None: + dtype = np.dtype(dtype) + + return TaggableCLArray( + queue, shape, dtype, + axes=axes, tags=tags, + order=order, allocator=allocator) + + +def zeros(queue, shape, dtype=float, *, + axes: Optional[Tuple[Axis, ...]] = None, + tags: FrozenSet[Tag] = frozenset(), + order: str = "C", + allocator=None) -> TaggableCLArray: + result = empty( + queue, shape, dtype=dtype, axes=axes, tags=tags, + order=order, allocator=allocator) + result._zero_fill() + + return result + + +def to_device(queue, ary, *, + axes: Optional[Tuple[Axis, ...]] = None, + tags: FrozenSet[Tag] = frozenset(), + allocator=None): + return to_tagged_cl_array( + cla.to_device(queue, ary, allocator=allocator), + axes=axes, tags=tags) + +# }}} -- GitLab