From d757876d875afe07c430b9602cb25f8e397c54a0 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Fri, 24 Jun 2022 20:58:46 +0300
Subject: [PATCH] add some helper functions to taggable_cl_array

---
 arraycontext/impl/pyopencl/__init__.py        | 75 +++++++------------
 .../impl/pyopencl/taggable_cl_array.py        | 65 ++++++++++++++--
 2 files changed, 86 insertions(+), 54 deletions(-)

diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py
index 1467246..b1a086d 100644
--- a/arraycontext/impl/pyopencl/__init__.py
+++ b/arraycontext/impl/pyopencl/__init__.py
@@ -154,27 +154,16 @@ class PyOpenCLArrayContext(ArrayContext):
     # {{{ ArrayContext interface
 
     def empty(self, shape, dtype):
-        from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
-
-        return TaggableCLArray(self.queue, shape=shape, dtype=dtype,
-                               allocator=self.allocator)
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
+        return tga.empty(self.queue, shape, dtype, allocator=self.allocator)
 
     def zeros(self, shape, dtype):
-        import pyopencl.array as cl_array
-        from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
-        return to_tagged_cl_array(cl_array.zeros(self.queue, shape=shape,
-                                                 dtype=dtype,
-                                                 allocator=self.allocator),
-                                  axes=None, tags=frozenset())
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
+        return tga.zeros(self.queue, shape, dtype, allocator=self.allocator)
 
     def from_numpy(self, array: Union[np.ndarray, ScalarLike]):
-        import pyopencl.array as cl_array
-        from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
-        return to_tagged_cl_array(cl_array
-                                  .to_device(self.queue,
-                                             array,
-                                             allocator=self.allocator),
-                                  axes=None, tags=frozenset())
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
+        return tga.to_device(self.queue, array, allocator=self.allocator)
 
     def to_numpy(self, array):
         if np.isscalar(array):
@@ -202,10 +191,9 @@ class PyOpenCLArrayContext(ArrayContext):
             if len(wait_event_queue) > self._wait_event_queue_length:
                 wait_event_queue.pop(0).wait()
 
-        from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
         # FIXME: Inherit loopy tags for these arrays
-        return {name: to_tagged_cl_array(ary, axes=None, tags=frozenset())
-                for name, ary in result.items()}
+        return {name: tga.to_tagged_cl_array(ary) for name, ary in result.items()}
 
     def freeze(self, array):
         import pyopencl.array as cl_array
@@ -222,24 +210,21 @@ class PyOpenCLArrayContext(ArrayContext):
                                   actx=None)
 
     def thaw(self, array):
-        from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray,
-                                                                  to_tagged_cl_array)
-        import pyopencl.array as cl_array
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
 
         def _rec_thaw(ary):
-            if isinstance(ary, TaggableCLArray):
+            if isinstance(ary, tga.TaggableCLArray):
                 return ary.with_queue(self.queue)
-            elif isinstance(ary, cl_array.Array):
+            elif isinstance(ary, self.array_types):
                 from warnings import warn
-                warn("Invoking PyOpenCLArrayContext.thaw with pyopencl.Array"
+                warn(f"Invoking PyOpenCLArrayContext.thaw with {type(ary).__name__}"
                     " will be unsupported in 2023. Use `to_tagged_cl_array`"
-                    " to convert instances of pyopencl.Array to TaggableCLArray.",
+                    " to convert instances to TaggableCLArray.",
                     DeprecationWarning, stacklevel=2)
-                return (to_tagged_cl_array(ary, axes=None, tags=frozenset())
-                        .with_queue(self.queue))
+                return (tga.to_tagged_cl_array(ary).with_queue(self.queue))
             else:
-                raise ValueError("array should be a cl.array.Array,"
-                                f" got '{type(ary)}'")
+                raise ValueError(
+                    f"array should be a cl.array.Array, got '{type(ary).__name__}'")
 
         return with_array_context(rec_map_array_container(_rec_thaw, array),
                                   actx=self)
@@ -317,35 +302,31 @@ class PyOpenCLArrayContext(ArrayContext):
         return t_unit
 
     def tag(self, tags: ToTagSetConvertible, array):
-        import pyopencl.array as cl_array
-        from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray,
-                                                                  to_tagged_cl_array)
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
 
         def _rec_tagged(ary):
-            if isinstance(ary, TaggableCLArray):
+            if isinstance(ary, tga.TaggableCLArray):
                 return ary.tagged(tags)
-            elif isinstance(ary, cl_array.Array):
-                return to_tagged_cl_array(ary, axes=None,  tags=tags)
+            elif isinstance(ary, self.array_types):
+                return tga.to_tagged_cl_array(ary, tags=tags)
             else:
-                raise ValueError("array should be a cl.array.Array,"
-                                 f" got '{type(ary)}'")
+                raise ValueError(
+                    f"array should be a cl.array.Array, got '{type(ary).__name__}'")
 
         return rec_map_array_container(_rec_tagged, array)
 
     def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
-        import pyopencl.array as cl_array
-        from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray,
-                                                                  to_tagged_cl_array)
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
 
         def _rec_tagged(ary):
-            if isinstance(ary, TaggableCLArray):
+            if isinstance(ary, tga.TaggableCLArray):
                 return ary.with_tagged_axis(iaxis, tags)
-            elif isinstance(ary, cl_array.Array):
-                return (to_tagged_cl_array(ary, axes=None,  tags=tags)
+            elif isinstance(ary, self.array_types):
+                return (tga.to_tagged_cl_array(ary, tags=tags)
                         .with_tagged_axis(iaxis, tags))
             else:
-                raise ValueError("array should be a cl.array.Array,"
-                                 f" got '{type(ary)}'")
+                raise ValueError(
+                    f"array should be a cl.array.Array, got '{type(ary).__name__}'")
 
         return rec_map_array_container(_rec_tagged, array)
 
diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py
index 439ca58..49ae08b 100644
--- a/arraycontext/impl/pyopencl/taggable_cl_array.py
+++ b/arraycontext/impl/pyopencl/taggable_cl_array.py
@@ -5,12 +5,17 @@
 .. autofunction:: to_tagged_cl_array
 """
 
-import pyopencl.array as cla
-from typing import Any, Dict, FrozenSet, Optional, Tuple
-from pytools.tag import Taggable, Tag, ToTagSetConvertible
 from dataclasses import dataclass
+from typing import Any, Dict, FrozenSet, Optional, Tuple
+
+import numpy as np
+import pyopencl.array as cla
+
 from pytools import memoize
+from pytools.tag import Taggable, Tag, ToTagSetConvertible
+
 
+# {{{ utils
 
 @dataclass(frozen=True, eq=True)
 class Axis(Taggable):
@@ -42,6 +47,10 @@ def _unwrap_cl_array(ary: cla.Array) -> Dict[str, Any]:
                 _fast=True,
                 )
 
+# }}}
+
+
+# {{{ TaggableCLArray
 
 class TaggableCLArray(cla.Array, Taggable):
     """
@@ -111,8 +120,8 @@ class TaggableCLArray(cla.Array, Taggable):
 
 
 def to_tagged_cl_array(ary: cla.Array,
-                       axes: Optional[Tuple[Axis, ...]],
-                       tags: FrozenSet[Tag]) -> TaggableCLArray:
+                       axes: Optional[Tuple[Axis, ...]] = None,
+                       tags: FrozenSet[Tag] = frozenset()) -> TaggableCLArray:
     """
     Returns a :class:`TaggableCLArray` that is constructed from the data in
     *ary* along with the metadata from *axes* and *tags*. If *ary* is already a
@@ -123,7 +132,7 @@ def to_tagged_cl_array(ary: cla.Array,
         array. If passed *None*, then initialized to a :class:`pytato.Axis`
         with no tags attached for each dimension.
     """
-    if axes and len(axes) != ary.ndim:
+    if axes is not None and len(axes) != ary.ndim:
         raise ValueError("axes length does not match array dimension: "
                          f"got {len(axes)} axes for {ary.ndim}d array")
 
@@ -131,7 +140,7 @@ def to_tagged_cl_array(ary: cla.Array,
     tags = normalize_tags(tags)
 
     if isinstance(ary, TaggableCLArray):
-        if axes:
+        if axes is not None:
             for i, axis in enumerate(axes):
                 ary = ary.with_tagged_axis(i, axis.tags)
 
@@ -144,3 +153,45 @@ def to_tagged_cl_array(ary: cla.Array,
                                **_unwrap_cl_array(ary))
     else:
         raise TypeError(f"unsupported array type: '{type(ary).__name__}'")
+
+# }}}
+
+
+# {{{ creation
+
+def empty(queue, shape, dtype=float, *,
+        axes: Optional[Tuple[Axis, ...]] = None,
+        tags: FrozenSet[Tag] = frozenset(),
+        order: str = "C",
+        allocator=None) -> TaggableCLArray:
+    if dtype is not None:
+        dtype = np.dtype(dtype)
+
+    return TaggableCLArray(
+        queue, shape, dtype,
+        axes=axes, tags=tags,
+        order=order, allocator=allocator)
+
+
+def zeros(queue, shape, dtype=float, *,
+        axes: Optional[Tuple[Axis, ...]] = None,
+        tags: FrozenSet[Tag] = frozenset(),
+        order: str = "C",
+        allocator=None) -> TaggableCLArray:
+    result = empty(
+        queue, shape, dtype=dtype, axes=axes, tags=tags,
+        order=order, allocator=allocator)
+    result._zero_fill()
+
+    return result
+
+
+def to_device(queue, ary, *,
+        axes: Optional[Tuple[Axis, ...]] = None,
+        tags: FrozenSet[Tag] = frozenset(),
+        allocator=None):
+    return to_tagged_cl_array(
+        cla.to_device(queue, ary, allocator=allocator),
+        axes=axes, tags=tags)
+
+# }}}
-- 
GitLab