Skip to content
Snippets Groups Projects
Commit d757876d authored by Alexandru Fikl's avatar Alexandru Fikl Committed by Andreas Klöckner
Browse files

add some helper functions to taggable_cl_array

parent 9109363d
No related branches found
No related tags found
No related merge requests found
......@@ -154,27 +154,16 @@ class PyOpenCLArrayContext(ArrayContext):
# {{{ ArrayContext interface
def empty(self, shape, dtype):
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
return TaggableCLArray(self.queue, shape=shape, dtype=dtype,
allocator=self.allocator)
import arraycontext.impl.pyopencl.taggable_cl_array as tga
return tga.empty(self.queue, shape, dtype, allocator=self.allocator)
def zeros(self, shape, dtype):
import pyopencl.array as cl_array
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
return to_tagged_cl_array(cl_array.zeros(self.queue, shape=shape,
dtype=dtype,
allocator=self.allocator),
axes=None, tags=frozenset())
import arraycontext.impl.pyopencl.taggable_cl_array as tga
return tga.zeros(self.queue, shape, dtype, allocator=self.allocator)
def from_numpy(self, array: Union[np.ndarray, ScalarLike]):
import pyopencl.array as cl_array
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
return to_tagged_cl_array(cl_array
.to_device(self.queue,
array,
allocator=self.allocator),
axes=None, tags=frozenset())
import arraycontext.impl.pyopencl.taggable_cl_array as tga
return tga.to_device(self.queue, array, allocator=self.allocator)
def to_numpy(self, array):
if np.isscalar(array):
......@@ -202,10 +191,9 @@ class PyOpenCLArrayContext(ArrayContext):
if len(wait_event_queue) > self._wait_event_queue_length:
wait_event_queue.pop(0).wait()
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
import arraycontext.impl.pyopencl.taggable_cl_array as tga
# FIXME: Inherit loopy tags for these arrays
return {name: to_tagged_cl_array(ary, axes=None, tags=frozenset())
for name, ary in result.items()}
return {name: tga.to_tagged_cl_array(ary) for name, ary in result.items()}
def freeze(self, array):
import pyopencl.array as cl_array
......@@ -222,24 +210,21 @@ class PyOpenCLArrayContext(ArrayContext):
actx=None)
def thaw(self, array):
from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray,
to_tagged_cl_array)
import pyopencl.array as cl_array
import arraycontext.impl.pyopencl.taggable_cl_array as tga
def _rec_thaw(ary):
if isinstance(ary, TaggableCLArray):
if isinstance(ary, tga.TaggableCLArray):
return ary.with_queue(self.queue)
elif isinstance(ary, cl_array.Array):
elif isinstance(ary, self.array_types):
from warnings import warn
warn("Invoking PyOpenCLArrayContext.thaw with pyopencl.Array"
warn(f"Invoking PyOpenCLArrayContext.thaw with {type(ary).__name__}"
" will be unsupported in 2023. Use `to_tagged_cl_array`"
" to convert instances of pyopencl.Array to TaggableCLArray.",
" to convert instances to TaggableCLArray.",
DeprecationWarning, stacklevel=2)
return (to_tagged_cl_array(ary, axes=None, tags=frozenset())
.with_queue(self.queue))
return (tga.to_tagged_cl_array(ary).with_queue(self.queue))
else:
raise ValueError("array should be a cl.array.Array,"
f" got '{type(ary)}'")
raise ValueError(
f"array should be a cl.array.Array, got '{type(ary).__name__}'")
return with_array_context(rec_map_array_container(_rec_thaw, array),
actx=self)
......@@ -317,35 +302,31 @@ class PyOpenCLArrayContext(ArrayContext):
return t_unit
def tag(self, tags: ToTagSetConvertible, array):
import pyopencl.array as cl_array
from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray,
to_tagged_cl_array)
import arraycontext.impl.pyopencl.taggable_cl_array as tga
def _rec_tagged(ary):
if isinstance(ary, TaggableCLArray):
if isinstance(ary, tga.TaggableCLArray):
return ary.tagged(tags)
elif isinstance(ary, cl_array.Array):
return to_tagged_cl_array(ary, axes=None, tags=tags)
elif isinstance(ary, self.array_types):
return tga.to_tagged_cl_array(ary, tags=tags)
else:
raise ValueError("array should be a cl.array.Array,"
f" got '{type(ary)}'")
raise ValueError(
f"array should be a cl.array.Array, got '{type(ary).__name__}'")
return rec_map_array_container(_rec_tagged, array)
def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
import pyopencl.array as cl_array
from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray,
to_tagged_cl_array)
import arraycontext.impl.pyopencl.taggable_cl_array as tga
def _rec_tagged(ary):
if isinstance(ary, TaggableCLArray):
if isinstance(ary, tga.TaggableCLArray):
return ary.with_tagged_axis(iaxis, tags)
elif isinstance(ary, cl_array.Array):
return (to_tagged_cl_array(ary, axes=None, tags=tags)
elif isinstance(ary, self.array_types):
return (tga.to_tagged_cl_array(ary, tags=tags)
.with_tagged_axis(iaxis, tags))
else:
raise ValueError("array should be a cl.array.Array,"
f" got '{type(ary)}'")
raise ValueError(
f"array should be a cl.array.Array, got '{type(ary).__name__}'")
return rec_map_array_container(_rec_tagged, array)
......
......@@ -5,12 +5,17 @@
.. autofunction:: to_tagged_cl_array
"""
import pyopencl.array as cla
from typing import Any, Dict, FrozenSet, Optional, Tuple
from pytools.tag import Taggable, Tag, ToTagSetConvertible
from dataclasses import dataclass
from typing import Any, Dict, FrozenSet, Optional, Tuple
import numpy as np
import pyopencl.array as cla
from pytools import memoize
from pytools.tag import Taggable, Tag, ToTagSetConvertible
# {{{ utils
@dataclass(frozen=True, eq=True)
class Axis(Taggable):
......@@ -42,6 +47,10 @@ def _unwrap_cl_array(ary: cla.Array) -> Dict[str, Any]:
_fast=True,
)
# }}}
# {{{ TaggableCLArray
class TaggableCLArray(cla.Array, Taggable):
"""
......@@ -111,8 +120,8 @@ class TaggableCLArray(cla.Array, Taggable):
def to_tagged_cl_array(ary: cla.Array,
axes: Optional[Tuple[Axis, ...]],
tags: FrozenSet[Tag]) -> TaggableCLArray:
axes: Optional[Tuple[Axis, ...]] = None,
tags: FrozenSet[Tag] = frozenset()) -> TaggableCLArray:
"""
Returns a :class:`TaggableCLArray` that is constructed from the data in
*ary* along with the metadata from *axes* and *tags*. If *ary* is already a
......@@ -123,7 +132,7 @@ def to_tagged_cl_array(ary: cla.Array,
array. If passed *None*, then initialized to a :class:`pytato.Axis`
with no tags attached for each dimension.
"""
if axes and len(axes) != ary.ndim:
if axes is not None and len(axes) != ary.ndim:
raise ValueError("axes length does not match array dimension: "
f"got {len(axes)} axes for {ary.ndim}d array")
......@@ -131,7 +140,7 @@ def to_tagged_cl_array(ary: cla.Array,
tags = normalize_tags(tags)
if isinstance(ary, TaggableCLArray):
if axes:
if axes is not None:
for i, axis in enumerate(axes):
ary = ary.with_tagged_axis(i, axis.tags)
......@@ -144,3 +153,45 @@ def to_tagged_cl_array(ary: cla.Array,
**_unwrap_cl_array(ary))
else:
raise TypeError(f"unsupported array type: '{type(ary).__name__}'")
# }}}
# {{{ creation
def empty(queue, shape, dtype=float, *,
axes: Optional[Tuple[Axis, ...]] = None,
tags: FrozenSet[Tag] = frozenset(),
order: str = "C",
allocator=None) -> TaggableCLArray:
if dtype is not None:
dtype = np.dtype(dtype)
return TaggableCLArray(
queue, shape, dtype,
axes=axes, tags=tags,
order=order, allocator=allocator)
def zeros(queue, shape, dtype=float, *,
axes: Optional[Tuple[Axis, ...]] = None,
tags: FrozenSet[Tag] = frozenset(),
order: str = "C",
allocator=None) -> TaggableCLArray:
result = empty(
queue, shape, dtype=dtype, axes=axes, tags=tags,
order=order, allocator=allocator)
result._zero_fill()
return result
def to_device(queue, ary, *,
axes: Optional[Tuple[Axis, ...]] = None,
tags: FrozenSet[Tag] = frozenset(),
allocator=None):
return to_tagged_cl_array(
cla.to_device(queue, ary, allocator=allocator),
axes=axes, tags=tags)
# }}}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment