Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • inducer/arraycontext
  • kaushikcfd/arraycontext
  • fikl2/arraycontext
3 results
Show changes
Showing
with 3096 additions and 635 deletions
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
.. currentmodule:: arraycontext .. currentmodule:: arraycontext
.. autoclass:: PyOpenCLArrayContext .. autoclass:: PyOpenCLArrayContext
""" """
from __future__ import annotations
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
""" """
...@@ -26,20 +29,23 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -26,20 +29,23 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from functools import partial, reduce
import operator import operator
from functools import partial, reduce
import numpy as np import numpy as np
from arraycontext.fake_numpy import \
BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace
from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container import NotAnArrayContainerError, serialize_container
from arraycontext.container.traversal import ( from arraycontext.container.traversal import (
rec_map_array_container, rec_map_array_container,
rec_multimap_array_container, rec_map_reduce_array_container,
rec_map_reduce_array_container, rec_multimap_array_container,
rec_multimap_reduce_array_container, rec_multimap_reduce_array_container,
) )
from arraycontext.context import Array, ArrayOrContainer
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
try: try:
import pyopencl as cl # noqa: F401 import pyopencl as cl # noqa: F401
...@@ -50,141 +56,82 @@ except ImportError: ...@@ -50,141 +56,82 @@ except ImportError:
# {{{ fake numpy # {{{ fake numpy
class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
def _get_fake_numpy_linalg_namespace(self): def _get_fake_numpy_linalg_namespace(self):
return _PyOpenCLFakeNumpyLinalgNamespace(self._array_context) return _PyOpenCLFakeNumpyLinalgNamespace(self._array_context)
# {{{ comparisons # NOTE: the order of these follows the order in numpy docs
# NOTE: when adding a function here, also add it to `array_context.rst` docs!
# FIXME: This should be documentation, not a comment.
# These are here mainly because some arrays may choose to interpret
# equality comparison as a binary predicate of structural identity,
# i.e. more like "are you two equal", and not like numpy semantics.
# These operations provide access to numpy-style comparisons in that
# case.
def equal(self, x, y):
return rec_multimap_array_container(operator.eq, x, y)
def not_equal(self, x, y):
return rec_multimap_array_container(operator.ne, x, y)
def greater(self, x, y):
return rec_multimap_array_container(operator.gt, x, y)
def greater_equal(self, x, y):
return rec_multimap_array_container(operator.ge, x, y)
def less(self, x, y):
return rec_multimap_array_container(operator.lt, x, y)
def less_equal(self, x, y):
return rec_multimap_array_container(operator.le, x, y)
# }}}
def ones_like(self, ary):
def _ones_like(subary):
ones = self._array_context.empty_like(subary)
ones.fill(1)
return ones
return self._new_like(ary, _ones_like)
def maximum(self, x, y):
return rec_multimap_array_container(
partial(cl_array.maximum, queue=self._array_context.queue),
x, y)
def minimum(self, x, y):
return rec_multimap_array_container(
partial(cl_array.minimum, queue=self._array_context.queue),
x, y)
def where(self, criterion, then, else_): # {{{ array creation routines
def where_inner(inner_crit, inner_then, inner_else):
if isinstance(inner_crit, bool):
return inner_then if inner_crit else inner_else
return cl_array.if_positive(inner_crit != 0, inner_then, inner_else,
queue=self._array_context.queue)
return rec_multimap_array_container(where_inner, criterion, then, else_) def zeros(self, shape, dtype) -> TaggableCLArray:
import arraycontext.impl.pyopencl.taggable_cl_array as tga
return tga.zeros(self._array_context.queue, shape, dtype,
allocator=self._array_context.allocator)
def sum(self, a, axis=None, dtype=None): def empty_like(self, ary):
from warnings import warn
warn(f"{type(self._array_context).__name__}.np.empty_like is "
"deprecated and will stop working in 2023. Prefer actx.np.zeros_like "
"instead.",
DeprecationWarning, stacklevel=2)
if isinstance(axis, int): import arraycontext.impl.pyopencl.taggable_cl_array as tga
axis = axis, actx = self._array_context
def _rec_sum(ary): def _empty_like(array):
if axis not in [None, tuple(range(ary.ndim))]: return tga.empty(actx.queue, array.shape, array.dtype,
raise NotImplementedError(f"Sum over '{axis}' axes not supported.") allocator=actx.allocator, axes=array.axes, tags=array.tags)
return cl_array.sum(ary, dtype=dtype, queue=self._array_context.queue) return actx._rec_map_container(_empty_like, ary)
result = rec_map_reduce_array_container(sum, _rec_sum, a) def zeros_like(self, ary):
import arraycontext.impl.pyopencl.taggable_cl_array as tga
actx = self._array_context
if not self._array_context._force_device_scalars: def _zeros_like(array):
result = result.get()[()] return tga.zeros(
return result actx.queue, array.shape, array.dtype,
allocator=actx.allocator, axes=array.axes, tags=array.tags)
def min(self, a, axis=None): return actx._rec_map_container(_zeros_like, ary, default_scalar=0)
queue = self._array_context.queue
if isinstance(axis, int): def ones_like(self, ary):
axis = axis, return self.full_like(ary, 1)
def _rec_min(ary): def full_like(self, ary, fill_value):
if axis not in [None, tuple(range(ary.ndim))]: import arraycontext.impl.pyopencl.taggable_cl_array as tga
raise NotImplementedError(f"Min. over '{axis}' axes not supported.") actx = self._array_context
return cl_array.min(ary, queue=queue)
result = rec_map_reduce_array_container( def _full_like(subary):
partial(reduce, partial(cl_array.minimum, queue=queue)), filled = tga.empty(
_rec_min, actx.queue, subary.shape, subary.dtype,
a) allocator=actx.allocator, axes=subary.axes, tags=subary.tags)
filled.fill(fill_value)
if not self._array_context._force_device_scalars: return filled
result = result.get()[()]
return result
def max(self, a, axis=None): return actx._rec_map_container(_full_like, ary, default_scalar=fill_value)
queue = self._array_context.queue
if isinstance(axis, int): def copy(self, ary):
axis = axis, def _copy(subary):
return subary.copy(queue=self._array_context.queue)
def _rec_max(ary): return self._array_context._rec_map_container(_copy, ary)
if axis not in [None, tuple(range(ary.ndim))]:
raise NotImplementedError(f"Max. over '{axis}' axes not supported.")
return cl_array.max(ary, queue=queue)
result = rec_map_reduce_array_container( def arange(self, *args, **kwargs):
partial(reduce, partial(cl_array.maximum, queue=queue)), return cl_array.arange(self._array_context.queue, *args, **kwargs)
_rec_max,
a)
if not self._array_context._force_device_scalars: # }}}
result = result.get()[()]
return result
def stack(self, arrays, axis=0): # {{{ array manipulation routines
return rec_multimap_array_container(
lambda *args: cl_array.stack(arrays=args, axis=axis,
queue=self._array_context.queue),
*arrays)
def reshape(self, a, newshape, order="C"): def reshape(self, a, newshape, order="C"):
return rec_map_array_container( return rec_map_array_container(
lambda ary: ary.reshape(newshape, order=order), lambda ary: ary.reshape(newshape, order=order),
a) a)
def concatenate(self, arrays, axis=0):
return cl_array.concatenate(
arrays, axis,
self._array_context.queue,
self._array_context.allocator
)
def ravel(self, a, order="C"): def ravel(self, a, order="C"):
def _rec_ravel(a): def _rec_ravel(a):
if order in "FC": if order in "FC":
...@@ -207,67 +154,209 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): ...@@ -207,67 +154,209 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace):
return rec_map_array_container(_rec_ravel, a) return rec_map_array_container(_rec_ravel, a)
def concatenate(self, arrays, axis=0):
return cl_array.concatenate(
arrays, axis,
self._array_context.queue,
self._array_context.allocator
)
def stack(self, arrays, axis=0):
return rec_multimap_array_container(
lambda *args: cl_array.stack(arrays=args, axis=axis,
queue=self._array_context.queue),
*arrays)
# }}}
# {{{ linear algebra
def vdot(self, x, y, dtype=None): def vdot(self, x, y, dtype=None):
result = rec_multimap_reduce_array_container( return rec_multimap_reduce_array_container(
sum, sum,
partial(cl_array.vdot, dtype=dtype, queue=self._array_context.queue), partial(cl_array.vdot, dtype=dtype, queue=self._array_context.queue),
x, y) x, y)
if not self._array_context._force_device_scalars: # }}}
result = result.get()[()]
return result
def any(self, a):
queue = self._array_context.queue
result = rec_map_reduce_array_container(
partial(reduce, partial(cl_array.maximum, queue=queue)),
lambda subary: subary.any(queue=queue),
a)
if not self._array_context._force_device_scalars: # {{{ logic functions
result = result.get()[()]
return result
def all(self, a): def all(self, a):
queue = self._array_context.queue queue = self._array_context.queue
result = rec_map_reduce_array_container(
def _all(ary):
if np.isscalar(ary):
return np.int8(all([ary]))
return ary.all(queue=queue)
return rec_map_reduce_array_container(
partial(reduce, partial(cl_array.minimum, queue=queue)), partial(reduce, partial(cl_array.minimum, queue=queue)),
lambda subary: subary.all(queue=queue), _all,
a) a)
if not self._array_context._force_device_scalars: def any(self, a):
result = result.get()[()] queue = self._array_context.queue
return result
def _any(ary):
if np.isscalar(ary):
return np.int8(any([ary]))
return ary.any(queue=queue)
return rec_map_reduce_array_container(
partial(reduce, partial(cl_array.maximum, queue=queue)),
_any,
a)
def array_equal(self, a, b): def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
actx = self._array_context actx = self._array_context
queue = actx.queue queue = actx.queue
# NOTE: pyopencl doesn't like `bool` much, so use `int8` instead # NOTE: pyopencl doesn't like `bool` much, so use `int8` instead
false = actx.from_numpy(np.int8(False)) true_ary = actx.from_numpy(np.int8(True))
false_ary = actx.from_numpy(np.int8(False))
def rec_equal(x, y): def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> cl_array.Array:
if type(x) != type(y): if type(x) is not type(y):
return false return false_ary
try: try:
iterable = zip(serialize_container(x), serialize_container(y)) serialized_x = serialize_container(x)
serialized_y = serialize_container(y)
except NotAnArrayContainerError: except NotAnArrayContainerError:
assert isinstance(x, cl_array.Array)
assert isinstance(y, cl_array.Array)
if x.shape != y.shape: if x.shape != y.shape:
return false return false_ary
else: else:
return (x == y).all() return (x == y).all()
else: else:
if len(serialized_x) != len(serialized_y):
return false_ary
return reduce( return reduce(
partial(cl_array.minimum, queue=queue), partial(cl_array.minimum, queue=queue),
[rec_equal(ix, iy)for (_, ix), (_, iy) in iterable] [(true_ary if kx_i == ky_i else false_ary)
) and rec_equal(x_i, y_i)
for (kx_i, x_i), (ky_i, y_i)
in zip(serialized_x, serialized_y, strict=True)],
true_ary)
return rec_equal(a, b)
# FIXME: This should be documentation, not a comment.
# These are here mainly because some arrays may choose to interpret
# equality comparison as a binary predicate of structural identity,
# i.e. more like "are you two equal", and not like numpy semantics.
# These operations provide access to numpy-style comparisons in that
# case.
def greater(self, x, y):
    """Return ``x > y``, applied leaf-wise.

    ``operator.gt`` is mapped over *x* and *y* with
    ``rec_multimap_array_container``.
    """
    compare = operator.gt
    return rec_multimap_array_container(compare, x, y)
def greater_equal(self, x, y):
    """Return ``x >= y``, applied leaf-wise.

    ``operator.ge`` is mapped over *x* and *y* with
    ``rec_multimap_array_container``.
    """
    compare = operator.ge
    return rec_multimap_array_container(compare, x, y)
def less(self, x, y):
    """Return ``x < y``, applied leaf-wise.

    ``operator.lt`` is mapped over *x* and *y* with
    ``rec_multimap_array_container``.
    """
    compare = operator.lt
    return rec_multimap_array_container(compare, x, y)
def less_equal(self, x, y):
    """Return ``x <= y``, applied leaf-wise.

    ``operator.le`` is mapped over *x* and *y* with
    ``rec_multimap_array_container``.
    """
    compare = operator.le
    return rec_multimap_array_container(compare, x, y)
def equal(self, x, y):
    """Return ``x == y``, applied leaf-wise.

    ``operator.eq`` is mapped over *x* and *y* with
    ``rec_multimap_array_container``.
    """
    compare = operator.eq
    return rec_multimap_array_container(compare, x, y)
def not_equal(self, x, y):
    """Return ``x != y``, applied leaf-wise.

    ``operator.ne`` is mapped over *x* and *y* with
    ``rec_multimap_array_container``.
    """
    compare = operator.ne
    return rec_multimap_array_container(compare, x, y)
result = rec_equal(a, b) def logical_or(self, x, y):
if not self._array_context._force_device_scalars: return rec_multimap_array_container(cl_array.logical_or, x, y)
result = result.get()[()]
return result def logical_and(self, x, y):
return rec_multimap_array_container(cl_array.logical_and, x, y)
def logical_not(self, x):
    """Map ``cl_array.logical_not`` over every leaf array of *x*."""
    negate = cl_array.logical_not
    return rec_map_array_container(negate, x)
# }}}
# {{{ mathematical functions
def sum(self, a, axis=None, dtype=None):
    """Sum all entries of *a*, reducing across container leaves with ``sum``.

    Only full reductions are supported: *axis* must be *None* or name every
    axis of each leaf array (a single :class:`int` is treated as a
    one-element tuple).

    :raises NotImplementedError: if *axis* selects a partial reduction.
    """
    if isinstance(axis, int):
        axis = (axis,)

    def _rec_sum(ary):
        full_reduction = [None, tuple(range(ary.ndim))]
        if axis in full_reduction:
            return cl_array.sum(
                ary, dtype=dtype, queue=self._array_context.queue)

        raise NotImplementedError(f"Sum over '{axis}' axes not supported.")

    return rec_map_reduce_array_container(sum, _rec_sum, a)
def maximum(self, x, y):
    """Apply ``cl_array.maximum`` to matching leaves of *x* and *y*,
    on the array context's queue.
    """
    leaf_maximum = partial(cl_array.maximum, queue=self._array_context.queue)
    return rec_multimap_array_container(leaf_maximum, x, y)
def amax(self, a, axis=None):
    """Return the maximum over all entries of *a*.

    Each leaf array is reduced with ``cl_array.max``; the per-leaf results
    are then folded together with ``cl_array.maximum``.

    Only full reductions are supported: *axis* must be *None* or name every
    axis of each leaf array (a single :class:`int` is treated as a
    one-element tuple).

    :raises NotImplementedError: if *axis* selects a partial reduction.
    """
    queue = self._array_context.queue

    if isinstance(axis, int):
        axis = (axis,)

    def _rec_max(ary):
        full_reduction = [None, tuple(range(ary.ndim))]
        if axis in full_reduction:
            return cl_array.max(ary, queue=queue)

        raise NotImplementedError(f"Max. over '{axis}' axes not supported.")

    pairwise_max = partial(cl_array.maximum, queue=queue)
    return rec_map_reduce_array_container(
        partial(reduce, pairwise_max), _rec_max, a)

max = amax
def minimum(self, x, y):
    """Apply ``cl_array.minimum`` to matching leaves of *x* and *y*,
    on the array context's queue.
    """
    leaf_minimum = partial(cl_array.minimum, queue=self._array_context.queue)
    return rec_multimap_array_container(leaf_minimum, x, y)
def amin(self, a, axis=None):
    """Return the minimum over all entries of *a*.

    Each leaf array is reduced with ``cl_array.min``; the per-leaf results
    are then folded together with ``cl_array.minimum``.

    Only full reductions are supported: *axis* must be *None* or name every
    axis of each leaf array (a single :class:`int` is treated as a
    one-element tuple).

    :raises NotImplementedError: if *axis* selects a partial reduction.
    """
    queue = self._array_context.queue

    if isinstance(axis, int):
        axis = (axis,)

    def _rec_min(ary):
        full_reduction = [None, tuple(range(ary.ndim))]
        if axis in full_reduction:
            return cl_array.min(ary, queue=queue)

        raise NotImplementedError(f"Min. over '{axis}' axes not supported.")

    pairwise_min = partial(cl_array.minimum, queue=queue)
    return rec_map_reduce_array_container(
        partial(reduce, pairwise_min), _rec_min, a)

min = amin
def absolute(self, a):
    """Alias for :meth:`abs`: forwards *a* to ``self.abs`` unchanged."""
    result = self.abs(a)
    return result
# }}}
# {{{ sorting, searching, and counting
def where(self, criterion, then, else_):
    """Select leaf-wise between *then* and *else_* based on *criterion*.

    A criterion leaf that is a :class:`bool` or :class:`numpy.bool_` picks
    one of the two branches directly. Any other leaf is compared against
    zero and handed to ``cl_array.if_positive`` on the array context's queue.
    """
    def where_inner(inner_crit, inner_then, inner_else):
        if isinstance(inner_crit, (bool, np.bool_)):
            if inner_crit:
                return inner_then
            return inner_else

        mask = inner_crit != 0
        return cl_array.if_positive(
            mask, inner_then, inner_else,
            queue=self._array_context.queue)

    return rec_multimap_array_container(where_inner, criterion, then, else_)
# }}}
# }}} # }}}
......
"""
.. autoclass:: TaggableCLArray
.. autoclass:: Axis
.. autofunction:: to_tagged_cl_array
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import numpy as np
import pyopencl.array as cla
from pytools import memoize
from pytools.tag import Tag, Taggable, ToTagSetConvertible
# {{{ utils
@dataclass(frozen=True, eq=True)
class Axis(Taggable):
    """
    Records the tags corresponding to a dimension of :class:`TaggableCLArray`.
    """
    # Tags attached to this axis.
    tags: frozenset[Tag]

    def _with_new_tags(self, tags: frozenset[Tag]) -> Axis:
        """Return a copy of *self* whose ``tags`` field is *tags*."""
        import dataclasses
        return dataclasses.replace(self, tags=tags)
@memoize
def _construct_untagged_axes(ndim: int) -> tuple[Axis, ...]:
    """Return a tuple of *ndim* :class:`Axis` objects with no tags attached."""
    untagged = []
    for _ in range(ndim):
        untagged.append(Axis(frozenset()))
    return tuple(untagged)
def _unwrap_cl_array(ary: cla.Array) -> dict[str, Any]:
    """Collect the attributes of *ary* into a keyword dictionary.

    Within this file the result is splatted into the
    :class:`TaggableCLArray` constructor to re-wrap the data of an
    existing array.
    """
    return dict(
        shape=ary.shape,
        dtype=ary.dtype,
        allocator=ary.allocator,
        strides=ary.strides,
        data=ary.base_data,
        offset=ary.offset,
        events=ary.events,
        _context=ary.context,
        _queue=ary.queue,
        _size=ary.size,
        _fast=True,
    )
# }}}
# {{{ TaggableCLArray
class TaggableCLArray(cla.Array, Taggable):
    """
    A :class:`pyopencl.array.Array` with additional metadata. This is used by
    :class:`~arraycontext.PytatoPyOpenCLArrayContext` to preserve tags for data
    while frozen, and also in a similar capacity by
    :class:`~arraycontext.PyOpenCLArrayContext`.

    .. attribute:: axes

        A :class:`tuple` of instances of :class:`Axis`, with one :class:`Axis`
        for each dimension of the array.

    .. attribute:: tags

        A :class:`frozenset` of :class:`pytools.tag.Tag`. Typically intended to
        record application-specific metadata to drive the optimizations in
        :meth:`arraycontext.PyOpenCLArrayContext.transform_loopy_program`.
    """
    def __init__(self, cq, shape, dtype, order="C", allocator=None,
                 data=None, offset=0, strides=None, events=None, _flags=None,
                 _fast=False, _size=None, _context=None, _queue=None,
                 axes=None, tags=frozenset()):
        """
        All arguments other than *axes* and *tags* are forwarded unchanged to
        the :class:`pyopencl.array.Array` constructor.

        :arg axes: one :class:`Axis` per array dimension, or *None* to use
            untagged axes.
        :arg tags: a :class:`frozenset` of tags for the array as a whole.
        :raises TypeError: (only when ``__debug__`` is true) if *tags* is not
            a :class:`frozenset`.
        :raises ValueError: (only when ``__debug__`` is true) if the length
            of *axes* does not match the array dimension.
        """
        super().__init__(cq=cq, shape=shape, dtype=dtype,
                         order=order, allocator=allocator,
                         data=data, offset=offset,
                         strides=strides, events=events,
                         _flags=_flags, _fast=_fast,
                         _size=_size, _context=_context,
                         _queue=_queue)

        if __debug__:
            if not isinstance(tags, frozenset):
                raise TypeError("tags are not a frozenset")

            if axes is not None and len(axes) != self.ndim:
                raise ValueError("axes length does not match array dimension: "
                                 f"got {len(axes)} axes for {self.ndim}d array")

        if axes is None:
            axes = _construct_untagged_axes(self.ndim)

        self.tags = tags
        self.axes = axes

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(shape={self.shape}, dtype={self.dtype}, "
                f"tags={self.tags}, axes={self.axes})")

    def copy(self, queue=cla._copy_queue):
        """Return a copy of the array that carries the same tags and axes."""
        ary = super().copy(queue=queue)
        return type(self)(None, tags=self.tags, axes=self.axes,
                          **_unwrap_cl_array(ary))

    def _with_new_tags(self, tags: frozenset[Tag]) -> TaggableCLArray:
        """Re-wrap the data of *self* with *tags*, keeping the axes."""
        return type(self)(None, tags=tags, axes=self.axes,
                          **_unwrap_cl_array(self))

    def with_tagged_axis(self, iaxis: int,
                         tags: ToTagSetConvertible) -> TaggableCLArray:
        """
        Returns a copy of *self* with *iaxis*-th axis tagged with *tags*.

        Negative values of *iaxis* count from the last axis.

        :raises IndexError: if *iaxis* is out of range for the array.
        """
        naxes = len(self.axes)
        if not -naxes <= iaxis < naxes:
            raise IndexError(
                f"axis index {iaxis} is out of range for {naxes}d array")

        # Normalize so that the 'iaxis+1' slice below cannot wrap around to
        # the start of the tuple (e.g. -1 + 1 == 0 would duplicate all axes).
        iaxis = iaxis % naxes

        new_axes = (self.axes[:iaxis]
                    + (self.axes[iaxis].tagged(tags),)
                    + self.axes[iaxis+1:])

        return type(self)(None, tags=self.tags, axes=new_axes,
                          **_unwrap_cl_array(self))
def to_tagged_cl_array(ary: cla.Array,
                       axes: tuple[Axis, ...] | None = None,
                       tags: frozenset[Tag] = frozenset()) -> TaggableCLArray:
    """
    Returns a :class:`TaggableCLArray` that is constructed from the data in
    *ary* along with the metadata from *axes* and *tags*. If *ary* is already a
    :class:`TaggableCLArray`, the new *tags* and *axes* are added to the
    existing ones.

    :arg axes: An instance of :class:`Axis` for each dimension of the
        array. If passed *None*, then initialized to an :class:`Axis`
        with no tags attached for each dimension.
    :arg tags: Tags to attach to the array as a whole; normalized with
        :func:`pytools.tag.normalize_tags`.
    :raises TypeError: if *ary* is not a :class:`pyopencl.array.Array`.
    :raises ValueError: if the length of *axes* does not match the
        dimension of *ary*.
    """
    # Check the type first: reading 'ary.ndim' on an unsupported object would
    # otherwise surface as an AttributeError rather than the TypeError below.
    if not isinstance(ary, cla.Array):
        raise TypeError(f"unsupported array type: '{type(ary).__name__}'")

    if axes is not None and len(axes) != ary.ndim:
        raise ValueError("axes length does not match array dimension: "
                         f"got {len(axes)} axes for {ary.ndim}d array")

    from pytools.tag import normalize_tags
    tags = normalize_tags(tags)

    if isinstance(ary, TaggableCLArray):
        # Already tagged: merge the new metadata into the existing one.
        if axes is not None:
            for i, axis in enumerate(axes):
                ary = ary.with_tagged_axis(i, axis.tags)

        if tags:
            ary = ary.tagged(tags)

        return ary

    return TaggableCLArray(None, tags=tags, axes=axes,
                           **_unwrap_cl_array(ary))
# }}}
# {{{ creation
def empty(queue, shape, dtype=float, *,
          axes: tuple[Axis, ...] | None = None,
          tags: frozenset[Tag] = frozenset(),
          order: str = "C",
          allocator=None) -> TaggableCLArray:
    """Construct a :class:`TaggableCLArray` of the given *shape* and *dtype*.

    *dtype* is converted with :class:`numpy.dtype` unless it is *None*. All
    other arguments are passed on to the :class:`TaggableCLArray` constructor.
    """
    normalized_dtype = None if dtype is None else np.dtype(dtype)

    return TaggableCLArray(
        queue, shape, normalized_dtype,
        axes=axes, tags=tags,
        order=order, allocator=allocator)
def zeros(queue, shape, dtype=float, *,
          axes: tuple[Axis, ...] | None = None,
          tags: frozenset[Tag] = frozenset(),
          order: str = "C",
          allocator=None) -> TaggableCLArray:
    """Construct a :class:`TaggableCLArray` via :func:`empty` and zero-fill
    it by calling its ``_zero_fill`` method.
    """
    creation_kwargs = {
        "dtype": dtype, "axes": axes, "tags": tags,
        "order": order, "allocator": allocator,
    }
    zeroed = empty(queue, shape, **creation_kwargs)
    zeroed._zero_fill()
    return zeroed
def to_device(queue, ary, *,
              axes: tuple[Axis, ...] | None = None,
              tags: frozenset[Tag] = frozenset(),
              allocator=None):
    """Transfer *ary* with :func:`pyopencl.array.to_device` and wrap the
    result using :func:`to_tagged_cl_array` with the given *axes* and *tags*.
    """
    device_ary = cla.to_device(queue, ary, allocator=allocator)
    return to_tagged_cl_array(device_ary, axes=axes, tags=tags)
# }}}
""" from __future__ import annotations
__doc__ = """
.. currentmodule:: arraycontext .. currentmodule:: arraycontext
A :mod:`pytato`-based array context defers the evaluation of an array until its A :mod:`pytato`-based array context defers the evaluation of an array until it is
frozen. The execution contexts for the evaluations are specific to an frozen. The execution contexts for the evaluations are specific to an
:class:`~arraycontext.ArrayContext` type. For ex. :class:`~arraycontext.ArrayContext` type. For example,
:class:`~arraycontext.PytatoPyOpenCLArrayContext` uses :mod:`pyopencl` to :class:`~arraycontext.PytatoPyOpenCLArrayContext` uses :mod:`pyopencl` to
JIT-compile and execute the array expressions. JIT-compile and execute the array expressions.
Following :mod:`pytato`-based array context are provided: The following :mod:`pytato`-based array contexts are provided:
.. autoclass:: PytatoPyOpenCLArrayContext .. autoclass:: PytatoPyOpenCLArrayContext
.. autoclass:: PytatoJAXArrayContext
Compiling a python callable Compiling a Python callable (Internal)
^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: arraycontext.impl.pytato.compile .. automodule:: arraycontext.impl.pytato.compile
Utils
^^^^^
.. automodule:: arraycontext.impl.pytato.utils
""" """
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
...@@ -41,18 +51,193 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -41,18 +51,193 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from arraycontext.context import ArrayContext, _ScalarLike import abc
import sys
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
import numpy as np import numpy as np
from typing import Any, Callable, Union, Sequence, TYPE_CHECKING
from pytools.tag import Tag from pytools import memoize_method
from pytools.tag import Tag, ToTagSetConvertible, normalize_tags
from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainer,
ScalarLike,
UntransformedCodeWarning,
)
from arraycontext.metadata import NameHint
if TYPE_CHECKING: if TYPE_CHECKING:
import loopy as lp
import pyopencl as cl
import pytato import pytato
if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
import pyopencl as cl
import logging
logger = logging.getLogger(__name__)
# {{{ tag conversion
def _preprocess_array_tags(tags: ToTagSetConvertible) -> frozenset[Tag]:
tags = normalize_tags(tags)
name_hints = [tag for tag in tags if isinstance(tag, NameHint)]
if name_hints:
name_hint, = name_hints
from pytato.tags import PrefixNamed
prefix_nameds = [tag for tag in tags if isinstance(tag, PrefixNamed)]
class PytatoPyOpenCLArrayContext(ArrayContext): if prefix_nameds:
prefix_named, = prefix_nameds
from warnings import warn
warn("When converting a "
f"arraycontext.metadata.NameHint('{name_hint.name}') "
"to pytato.tags.PrefixNamed, "
f"PrefixNamed('{prefix_named.prefix}') "
"was already present.", stacklevel=1)
tags = (
(tags | frozenset({PrefixNamed(name_hint.name)}))
- {name_hint})
return tags
# }}}
class _NotOnlyDataWrappers(Exception):  # noqa: N818
    """Internal exception type.

    NOTE(review): the raise/except sites are not visible in this part of the
    file; presumably signals that something other than data wrappers was
    encountered -- confirm against the users of this class.
    """
# {{{ _BasePytatoArrayContext
class _BasePytatoArrayContext(ArrayContext, abc.ABC):
""" """
A :class:`ArrayContext` that uses :mod:`pytato` data types to represent An abstract :class:`ArrayContext` that uses :mod:`pytato` data types to
represent.
.. automethod:: __init__
.. automethod:: transform_dag
.. automethod:: compile
"""
def __init__(
        self, *,
        compile_trace_callback: Callable[[Any, str, Any], None] | None = None
     ) -> None:
    """
    :arg compile_trace_callback: A function of three arguments
        *(what, stage, ir)*, where *what* identifies the object
        being compiled, *stage* is a string describing the compilation
        pass, and *ir* is an object containing the intermediate
        representation. This interface should be considered
        unstable.
    """
    super().__init__()

    import pytato as pt

    # Both caches are keyed by pytato DAGs and start out empty.
    self._freeze_prg_cache: dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {}
    self._dag_transform_cache: dict[
            pt.DictOfNamedArrays,
            tuple[pt.DictOfNamedArrays, str]] = {}

    def _compile_trace_callback(what, stage, ir):
        # Do-nothing stand-in used when the caller supplies no callback.
        pass

    self._compile_trace_callback = (
        _compile_trace_callback
        if compile_trace_callback is None
        else compile_trace_callback)
def _get_fake_numpy_namespace(self):
    """Return a ``PytatoFakeNumpyNamespace`` bound to this array context."""
    # Imported here rather than at module level, as in the original code.
    from arraycontext.impl.pytato import fake_numpy
    return fake_numpy.PytatoFakeNumpyNamespace(self)
# 'abc.abstractproperty' is deprecated since Python 3.3; stacking 'property'
# on top of 'abc.abstractmethod' is the documented replacement.
@property
@abc.abstractmethod
def _frozen_array_types(self) -> tuple[type, ...]:
    """
    Returns valid frozen array types for the array context.
    """
# {{{ compilation
def transform_dag(self, dag: pytato.DictOfNamedArrays
                  ) -> pytato.DictOfNamedArrays:
    """
    Hook for context-specific rewriting of *dag*, such as domain-specific
    optimizations. Sub-classes are expected to override it; every
    :mod:`pytato` DAG that is compiled to a GPU-kernel is passed through
    this routine. This base implementation hands *dag* back unchanged.

    :arg dag: An instance of :class:`pytato.DictOfNamedArrays`
    :returns: A transformed version of *dag*.
    """
    transformed_dag = dag
    return transformed_dag
def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
    """
    Return *t_unit* unchanged after emitting an
    :class:`UntransformedCodeWarning`.

    Sub-classes are expected to override this method with the transform
    logic required by their package or application.
    """
    from warnings import warn

    # Adjacent string literals are concatenated verbatim, so every fragment
    # that is followed by another sentence or word ends in a space.
    warn("Using the base "
         f"{type(self).__name__}.transform_loopy_program "
         "to transform a translation unit. "
         "This is a no-op and will result in unoptimized C code for "
         "the requested optimization, all in a single statement. "
         "This will work, but is unlikely to be performant. "
         f"Instead, subclass {type(self).__name__} and implement "
         "the specific transform logic required to transform the program "
         "for your package or application. Check higher-level packages "
         "(e.g. meshmode), which may already have subclasses you may want "
         "to build on.",
         UntransformedCodeWarning, stacklevel=2)

    return t_unit
@abc.abstractmethod
def einsum(self, spec, *args, arg_names=None, tagged=()):
    """Abstract; concrete array contexts must provide the implementation."""
# }}}
# {{{ properties
@property
def permits_inplace_modification(self):
    """Always *False* for this array context."""
    return False
@property
def supports_nonscalar_broadcasting(self):
    """Always *True* for this array context."""
    return True
@property
def permits_advanced_indexing(self):
    """Always *True* for this array context."""
    return True
def get_target(self):
    """Return *None*; this base implementation selects no target."""
    target = None
    return target
# }}}
# }}}
# {{{ PytatoPyOpenCLArrayContext
class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
"""
An :class:`ArrayContext` that uses :mod:`pytato` data types to represent
the arrays targeting OpenCL for offloading operations. the arrays targeting OpenCL for offloading operations.
.. attribute:: queue .. attribute:: queue
...@@ -70,191 +255,682 @@ class PytatoPyOpenCLArrayContext(ArrayContext): ...@@ -70,191 +255,682 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
.. automethod:: compile .. automethod:: compile
""" """
def __init__(
self, queue: cl.CommandQueue, allocator=None, *,
use_memory_pool: bool | None = None,
compile_trace_callback: Callable[[Any, str, Any], None] | None = None,
# do not use: only for testing
_force_svm_arg_limit: int | None = None,
) -> None:
"""
:arg compile_trace_callback: A function of three arguments
*(what, stage, ir)*, where *what* identifies the object
being compiled, *stage* is a string describing the compilation
pass, and *ir* is an object containing the intermediate
representation. This interface should be considered
unstable.
"""
if allocator is not None and use_memory_pool is not None:
raise TypeError("may not specify both allocator and use_memory_pool")
self.using_svm = None
if allocator is None:
from pyopencl.characterize import has_coarse_grain_buffer_svm
has_svm = has_coarse_grain_buffer_svm(queue.device)
if has_svm:
self.using_svm = True
def __init__(self, queue, allocator=None): from pyopencl.tools import SVMAllocator
allocator = SVMAllocator(queue.context, queue=queue)
if use_memory_pool:
from pyopencl.tools import SVMPool
allocator = SVMPool(allocator)
else:
self.using_svm = False
from pyopencl.tools import ImmediateAllocator
allocator = ImmediateAllocator(queue)
if use_memory_pool:
from pyopencl.tools import MemoryPool
allocator = MemoryPool(allocator)
else:
# Check whether the passed allocator allocates SVM
try:
from pyopencl import SVMPointer
mem = allocator(4)
if isinstance(mem, SVMPointer):
self.using_svm = True
else:
self.using_svm = False
except ImportError:
self.using_svm = False
import pyopencl.array as cla
import pytato as pt import pytato as pt
super().__init__() super().__init__(compile_trace_callback=compile_trace_callback)
self.queue = queue self.queue = queue
self.allocator = allocator self.allocator = allocator
self.array_types = (pt.Array, ) self.array_types = (pt.Array, cla.Array)
self._freeze_prg_cache = {}
# unused, but necessary to keep the context alive # unused, but necessary to keep the context alive
self.context = self.queue.context self.context = self.queue.context
def _get_fake_numpy_namespace(self): self._force_svm_arg_limit = _force_svm_arg_limit
from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace
return PytatoFakeNumpyNamespace(self)
# {{{ ArrayContext interface @property
def _frozen_array_types(self) -> tuple[type, ...]:
import pyopencl.array as cla
return (cla.Array,)
def clone(self): def _rec_map_container(
return type(self)(self.queue, self.allocator) self, func: Callable[[Array], Array], array: ArrayOrContainer,
allowed_types: tuple[type, ...] | None = None, *,
default_scalar: ScalarLike | None = None,
strict: bool = False) -> ArrayOrContainer:
import pytato as pt
def empty(self, shape, dtype): import arraycontext.impl.pyopencl.taggable_cl_array as tga
raise ValueError("PytatoPyOpenCLArrayContext does not support empty")
def zeros(self, shape, dtype): if allowed_types is None:
import pytato as pt allowed_types = (pt.Array, tga.TaggableCLArray)
return pt.zeros(shape, dtype)
def _wrapper(ary):
if isinstance(ary, allowed_types):
return func(ary)
elif np.isscalar(ary):
if default_scalar is None:
return ary
else:
return np.array(ary).dtype.type(default_scalar)
else:
raise TypeError(
f"{func.__qualname__} invoked with "
f"an unsupported array type: got '{type(ary).__name__}', "
f"but expected one of {allowed_types}")
return rec_map_array_container(_wrapper, array)
# {{{ ArrayContext interface
def from_numpy(self, array: Union[np.ndarray, _ScalarLike]): def from_numpy(self, array):
import pytato as pt import pytato as pt
import pyopencl.array as cla
cl_array = cla.to_device(self.queue, array)
return pt.make_data_wrapper(cl_array)
def to_numpy(self, array): import arraycontext.impl.pyopencl.taggable_cl_array as tga
if np.isscalar(array):
return array
cl_array = self.freeze(array) def _from_numpy(ary):
return cl_array.get(queue=self.queue) return pt.make_data_wrapper(
tga.to_device(self.queue, ary, allocator=self.allocator)
)
def call_loopy(self, program, **kwargs): return with_array_context(
import pyopencl.array as cla self._rec_map_container(_from_numpy, array, (np.ndarray,), strict=True),
from pytato.loopy import call_loopy actx=self)
entrypoint = program.default_entrypoint.name def to_numpy(self, array):
def _to_numpy(ary):
return ary.get(queue=self.queue)
return with_array_context(
self._rec_map_container(_to_numpy, self.freeze(array)),
actx=None)
@memoize_method
def get_target(self):
import pyopencl as cl
import pyopencl.characterize as cl_char
dev = self.queue.device
if (
self._force_svm_arg_limit is not None
or (
self.using_svm and dev.type & cl.device_type.GPU
and cl_char.has_coarse_grain_buffer_svm(dev))):
if dev.max_parameter_size == 4352:
# Nvidia devices and PTXAS declare a limit of 4352 bytes,
# which is incorrect. The CUDA documentation at
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#function-parameters
# mentions a limit of 4KB, which is also incorrect.
# As far as I can tell, the actual limit is around 4080
# bytes, at least on a K40. Reducing the limit further
# in order to be on the safe side.
# Note that the naming convention isn't super consistent
# for Nvidia GPUs, so that we only use the maximum
# parameter size to determine if it is an Nvidia GPU.
limit = 4096-200
from warnings import warn
warn("Running on an Nvidia GPU, reducing the argument "
f"size limit from 4352 to {limit}.", stacklevel=1)
else:
limit = dev.max_parameter_size
if self._force_svm_arg_limit is not None:
limit = self._force_svm_arg_limit
# thaw frozen arrays logger.info(
kwargs = {kw: (self.thaw(arg) if isinstance(arg, cla.Array) else arg) "limiting argument buffer size for %s to %d bytes",
for kw, arg in kwargs.items()} dev, limit)
return call_loopy(program, kwargs, entrypoint) from arraycontext.impl.pytato.utils import (
ArgSizeLimitingPytatoLoopyPyOpenCLTarget,
)
return ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit)
else:
return super().get_target()
def freeze(self, array): def freeze(self, array):
import pytato as pt if np.isscalar(array):
return array
import pyopencl.array as cla import pyopencl.array as cla
import loopy as lp import pytato as pt
if isinstance(array, cla.Array): from arraycontext.container.traversal import rec_keyed_map_array_container
return array.with_queue(None) from arraycontext.impl.pyopencl.taggable_cl_array import (
if not isinstance(array, pt.Array): TaggableCLArray,
raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with " to_tagged_cl_array,
f"non-pytato array of type '{type(array)}'") )
from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
from arraycontext.impl.pytato.utils import (
_normalize_pt_expr,
get_cl_axes_from_pt_axes,
)
array_as_dict: dict[str, cla.Array | TaggableCLArray | pt.Array] = {}
key_to_frozen_subary: dict[str, TaggableCLArray] = {}
key_to_pt_arrays: dict[str, pt.Array] = {}
def _record_leaf_ary_in_dict(
key: tuple[Any, ...],
ary: cla.Array | TaggableCLArray | pt.Array) -> None:
key_str = "_ary" + _ary_container_key_stringifier(key)
array_as_dict[key_str] = ary
rec_keyed_map_array_container(_record_leaf_ary_in_dict, array)
# {{{ remove any non pytato arrays from array_as_dict
for key, subary in array_as_dict.items():
if isinstance(subary, TaggableCLArray):
key_to_frozen_subary[key] = subary.with_queue(None)
elif isinstance(subary, self._frozen_array_types):
from warnings import warn
warn(f"Invoking {type(self).__name__}.freeze with"
f" {type(subary).__name__} will be unsupported in 2023. Use"
" `to_tagged_cl_array` to convert instances to TaggableCLArray.",
DeprecationWarning, stacklevel=2)
key_to_frozen_subary[key] = (
to_tagged_cl_array(subary.with_queue(None)))
elif isinstance(subary, pt.DataWrapper):
# trivial freeze.
key_to_frozen_subary[key] = to_tagged_cl_array(
subary.data,
axes=get_cl_axes_from_pt_axes(subary.axes),
tags=subary.tags)
elif isinstance(subary, pt.Array):
# Don't be tempted to take shortcuts here, e.g. for empty
# arrays, as this will inhibit metadata propagation that
# may happen in transform_dag below. See
# https://github.com/inducer/arraycontext/pull/167#issuecomment-1151877480
key_to_pt_arrays[key] = subary
else:
raise TypeError(
f"{type(self).__name__}.freeze invoked with an unsupported "
f"array type: got '{type(subary).__name__}', but expected one "
f"of {self.array_types}")
# {{{ early exit for 0-sized arrays # }}}
if array.size == 0: def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
return cla.empty(self.queue.context, key_str = "_ary" + _ary_container_key_stringifier(key)
shape=array.shape, return key_to_frozen_subary[key_str]
dtype=array.dtype,
allocator=self.allocator)
# }}} if not key_to_pt_arrays:
# all cl arrays => no need to perform any codegen
return with_array_context(
rec_keyed_map_array_container(_to_frozen, array),
actx=None)
from arraycontext.impl.pytato.utils import _normalize_pt_expr
pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
{"_actx_out": array}) key_to_pt_arrays)
normalized_expr, bound_arguments = _normalize_pt_expr( normalized_expr, bound_arguments = _normalize_pt_expr(
pt_dict_of_named_arrays) pt_dict_of_named_arrays)
try: try:
pt_prg = self._freeze_prg_cache[normalized_expr] pt_prg = self._freeze_prg_cache[normalized_expr]
except KeyError: except KeyError:
pt_prg = pt.generate_loopy(self.transform_dag(normalized_expr), try:
options=lp.Options(return_dict=True, transformed_dag, function_name = (
no_numpy=True), self._dag_transform_cache[normalized_expr])
cl_device=self.queue.device) except KeyError:
pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program) transformed_dag = self.transform_dag(normalized_expr)
from pytato.tags import PrefixNamed
name_hint_tags = []
for subary in key_to_pt_arrays.values():
name_hint_tags.extend(subary.tags_of_type(PrefixNamed))
from pytools import common_prefix
name_hint = common_prefix([nh.prefix for nh in name_hint_tags])
# All name_hint_tags shared at least some common prefix.
function_name = f"frozen_{name_hint}" if name_hint else "frozen_result"
self._dag_transform_cache[normalized_expr] = (
transformed_dag, function_name)
from arraycontext.loopy import _DEFAULT_LOOPY_OPTIONS
opts = _DEFAULT_LOOPY_OPTIONS
assert opts.return_dict
pt_prg = pt.generate_loopy(transformed_dag,
options=opts,
cl_device=self.queue.device,
function_name=function_name,
target=self.get_target()
).bind_to_context(self.context)
pt_prg = pt_prg.with_transformed_translation_unit(
self.transform_loopy_program)
self._freeze_prg_cache[normalized_expr] = pt_prg self._freeze_prg_cache[normalized_expr] = pt_prg
else:
transformed_dag, function_name = (
self._dag_transform_cache[normalized_expr])
assert len(pt_prg.bound_arguments) == 0 assert len(pt_prg.bound_arguments) == 0
evt, out_dict = pt_prg(self.queue, **bound_arguments) evt, out_dict = pt_prg(self.queue,
allocator=self.allocator,
**bound_arguments)
evt.wait() evt.wait()
assert len(set(out_dict) & set(key_to_frozen_subary)) == 0
key_to_frozen_subary = {
**key_to_frozen_subary,
**{k: to_tagged_cl_array(
v.with_queue(None),
axes=get_cl_axes_from_pt_axes(transformed_dag[k].expr.axes),
tags=transformed_dag[k].expr.tags)
for k, v in out_dict.items()}
}
return out_dict["_actx_out"].with_queue(None) return with_array_context(
rec_keyed_map_array_container(_to_frozen, array),
actx=None)
def thaw(self, array): def thaw(self, array):
import pytato as pt import pytato as pt
import pyopencl.array as cla
if not isinstance(array, cla.Array): import arraycontext.impl.pyopencl.taggable_cl_array as tga
raise TypeError("PytatoPyOpenCLArrayContext.thaw expects CL arrays, got " from .utils import get_pt_axes_from_cl_axes
f"{type(array)}")
def _thaw(ary):
return pt.make_data_wrapper(ary.with_queue(self.queue),
axes=get_pt_axes_from_cl_axes(ary.axes),
tags=ary.tags)
return pt.make_data_wrapper(array.with_queue(self.queue)) return with_array_context(
self._rec_map_container(_thaw, array, (tga.TaggableCLArray,)),
actx=self)
def freeze_thaw(self, array):
    """
    Return *array* re-associated with this array context, skipping the
    generic freeze/thaw round trip when possible.

    If every leaf handed to the mapper is already a
    :class:`pytato.DataWrapper` or a ``TaggableCLArray``, the leaves are
    passed through unchanged. As soon as any other leaf is encountered,
    the superclass implementation of ``freeze_thaw`` is used instead.

    :arg array: an array or array container.
    """
    import pytato as pt

    from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray

    passthrough_types = (pt.DataWrapper, TaggableCLArray)

    # NOTE: the helper must keep the name '_ft'; _rec_map_container derives
    # part of its error message from the mapped function's __name__.
    def _ft(ary):
        if not isinstance(ary, passthrough_types):
            # Signal to the enclosing 'try' that the shortcut does not apply.
            raise _NotOnlyDataWrappers()
        return ary

    try:
        unchanged = self._rec_map_container(_ft, array)
        return with_array_context(unchanged, actx=self)
    except _NotOnlyDataWrappers:
        return super().freeze_thaw(array)
def tag(self, tags: ToTagSetConvertible, array):
    """
    Return a copy of *array* in which every array leaf has been tagged
    with *tags*.

    :arg tags: the tag(s) to attach; they are normalized through
        ``_preprocess_array_tags`` before being applied.
    :arg array: an array or array container.
    """
    # NOTE: the helper must keep the name '_tag'; _rec_map_container derives
    # part of its error message from the mapped function's __name__.
    def _tag(ary):
        normalized_tags = _preprocess_array_tags(tags)
        return ary.tagged(normalized_tags)

    tagged_array = self._rec_map_container(_tag, array)
    return tagged_array
def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
    """
    Return a copy of *array* in which axis *iaxis* of every array leaf has
    been tagged with *tags*.

    :arg iaxis: index of the axis to tag.
    :arg tags: the tag(s) handed to each leaf's ``with_tagged_axis``.
    :arg array: an array or array container.
    """
    # NOTE: the helper must keep the name '_tag_axis'; _rec_map_container
    # derives part of its error message from the mapped function's __name__.
    def _tag_axis(ary):
        return ary.with_tagged_axis(iaxis, tags)

    axis_tagged = self._rec_map_container(_tag_axis, array)
    return axis_tagged
# }}} # }}}
# {{{ compilation
def call_loopy(self, program, **kwargs):
    """
    Build a :mod:`pytato` node that calls the default entrypoint of the
    :mod:`loopy` translation unit *program* with *kwargs* as arguments.

    :arg program: the translation unit to call; its
        ``default_entrypoint`` is the kernel that gets invoked.
    :arg kwargs: keyword arguments for the kernel. Each value must be a
        :class:`pytato.Array`, a scalar (one of pytato's
        ``SCALAR_CLASSES``), or a ``TaggableCLArray``; the latter is thawed
        into this array context before the call.
    :returns: the result of :func:`pytato.loopy.call_loopy`.
    :raises ValueError: if an argument has an unsupported type.
    """
    import pytato as pt
    from pytato.loopy import call_loopy
    from pytato.scalar_expr import SCALAR_CLASSES

    from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray

    entrypoint = program.default_entrypoint.name

    # {{{ preprocess args

    processed_kwargs = {}

    # Iterate in sorted keyword order so that the argument dictionary is
    # assembled in an order that does not depend on the caller's ordering.
    for kw, arg in sorted(kwargs.items()):
        if isinstance(arg, (pt.Array, *SCALAR_CLASSES)):
            pass
        elif isinstance(arg, TaggableCLArray):
            arg = self.thaw(arg)
        else:
            # The message pieces are joined by implicit string
            # concatenation, so each piece must carry its own separating
            # space.
            raise ValueError(f"call_loopy argument '{kw}' expected to be an"
                             " instance of 'pytato.Array', 'Number' or"
                             f" 'TaggableCLArray', got '{type(arg)}'")

        processed_kwargs[kw] = arg

    # }}}

    return call_loopy(program, processed_kwargs, entrypoint)
def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
from arraycontext.impl.pytato.compile import LazilyCompilingFunctionCaller from .compile import LazilyPyOpenCLCompilingFunctionCaller
return LazilyCompilingFunctionCaller(self, f) return LazilyPyOpenCLCompilingFunctionCaller(self, f)
def transform_loopy_program(self, t_unit): def transform_dag(self, dag: pytato.DictOfNamedArrays
raise ValueError("PytatoPyOpenCLArrayContext does not implement " ) -> pytato.DictOfNamedArrays:
"transform_loopy_program. Sub-classes are supposed " import pytato as pt
"to implement it.") dag = pt.transform.materialize_with_mpms(dag)
return dag
def transform_dag(self, dag: "pytato.DictOfNamedArrays" def einsum(self, spec, *args, arg_names=None, tagged=()):
) -> "pytato.DictOfNamedArrays": import pytato as pt
"""
Returns a transformed version of *dag*. Sub-classes are supposed to
override this method to implement context-specific transformations on
*dag* (most likely to perform domain-specific optimizations). Every
:mod:`pytato` DAG that is compiled to a :mod:`pyopencl` kernel is
passed through this routine.
:arg dag: An instance of :class:`pytato.DictOfNamedArrays` import arraycontext.impl.pyopencl.taggable_cl_array as tga
:returns: A transformed version of *dag*.
if arg_names is None:
arg_names = (None,) * len(args)
def preprocess_arg(name, arg):
if isinstance(arg, tga.TaggableCLArray):
ary = self.thaw(arg)
elif isinstance(arg, self._frozen_array_types):
from warnings import warn
warn(f"Invoking {type(self).__name__}.einsum with"
f" {type(arg).__name__} will be unsupported in 2023. Use"
" `to_tagged_cl_array` to convert instances to TaggableCLArray.",
DeprecationWarning, stacklevel=2)
ary = self.thaw(tga.to_tagged_cl_array(arg))
elif isinstance(arg, pt.Array):
ary = arg
else:
raise TypeError(
f"{type(self).__name__}.einsum invoked with an unsupported "
f"array type: got '{type(arg).__name__}', but expected one "
f"of {self.array_types}")
if name is not None: # noqa: SIM102
# Tagging Placeholders with naming-related tags is pointless:
# They already have names. It's also counterproductive, as
# multiple placeholders with the same name that are not
# also the same object are not allowed, and this would produce
# a different Placeholder object of the same name.
if (not isinstance(ary, pt.Placeholder)
and not ary.tags_of_type(NameHint)):
ary = ary.tagged(NameHint(name))
return ary
return pt.einsum(spec, *[
preprocess_arg(name, arg)
for name, arg in zip(arg_names, args, strict=True)
]).tagged(_preprocess_array_tags(tagged))
def clone(self):
return type(self)(self.queue, self.allocator)
# }}}
# }}}
# {{{ PytatoJAXArrayContext
class PytatoJAXArrayContext(_BasePytatoArrayContext):
"""
An arraycontext that uses :mod:`pytato` to represent the thawed state of
the arrays and compiles the expressions using
:class:`pytato.target.python.JAXPythonTarget`.
"""
def __init__(self,
*,
compile_trace_callback: Callable[[Any, str, Any], None] | None = None,
) -> None:
"""
:arg compile_trace_callback: A function of three arguments
*(what, stage, ir)*, where *what* identifies the object
being compiled, *stage* is a string describing the compilation
pass, and *ir* is an object containing the intermediate
representation. This interface should be considered
unstable.
""" """
import jax.numpy as jnp
import pytato as pt import pytato as pt
super().__init__(compile_trace_callback=compile_trace_callback)
self.array_types = (pt.Array, jnp.ndarray)
dag = pt.transform.materialize_with_mpms(dag) @property
def _frozen_array_types(self) -> tuple[type, ...]:
import jax.numpy as jnp
return (jnp.ndarray, )
def _rec_map_container(
self, func: Callable[[Array], Array], array: ArrayOrContainer,
allowed_types: tuple[type, ...] | None = None, *,
default_scalar: ScalarLike | None = None,
strict: bool = False) -> ArrayOrContainer:
if allowed_types is None:
allowed_types = self.array_types
def _wrapper(ary):
if isinstance(ary, allowed_types):
return func(ary)
elif np.isscalar(ary):
if default_scalar is None:
return ary
else:
return np.array(ary).dtype.type(default_scalar)
else:
raise TypeError(
f"{type(self).__name__}.{func.__name__[1:]} invoked with "
f"an unsupported array type: got '{type(ary).__name__}', "
f"but expected one of {allowed_types}")
return dag return rec_map_array_container(_wrapper, array)
def tag(self, tags: Union[Sequence[Tag], Tag], array): # {{{ ArrayContext interface
return array.tagged(tags)
def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array): def from_numpy(self, array):
# TODO import jax
from warnings import warn import pytato as pt
warn("tagging PytatoPyOpenCLArrayContext's array axes: not yet implemented",
stacklevel=2) def _from_numpy(ary):
return array return pt.make_data_wrapper(jax.device_put(ary))
return with_array_context(
self._rec_map_container(_from_numpy, array, (np.ndarray,)),
actx=self)
def to_numpy(self, array):
    """
    Freeze *array* and fetch every resulting leaf with
    :func:`jax.device_get`.

    :arg array: an array or array container.
    :returns: the mapped result, with its array context set to *None*.
    """
    import jax

    # NOTE: the helper must keep the name '_to_numpy'; _rec_map_container
    # derives part of its error message from the mapped function's __name__.
    def _to_numpy(ary):
        return jax.device_get(ary)

    frozen = self.freeze(array)
    on_host = self._rec_map_container(_to_numpy, frozen)
    return with_array_context(on_host, actx=None)
def freeze(self, array):
    """
    Evaluate *array* and return it with every leaf replaced by a computed
    JAX array, detached from any array context.

    Scalars are returned unchanged. JAX arrays and
    :class:`pytato.DataWrapper` leaves are only synchronized via
    ``block_until_ready``. All remaining :class:`pytato.Array` leaves are
    collected into a single :class:`pytato.DictOfNamedArrays`, passed
    through :meth:`transform_dag`, and compiled and run with
    :func:`pytato.generate_jax`.

    :arg array: a scalar, an array, or an array container.
    :raises TypeError: if a leaf is neither a JAX array nor a
        :class:`pytato.Array`.
    """
    if np.isscalar(array):
        return array

    import jax.numpy as jnp
    import pytato as pt

    from arraycontext.container.traversal import rec_keyed_map_array_container
    from arraycontext.impl.pytato.compile import _ary_container_key_stringifier

    def _leaf_name(key):
        # Unique string name for the leaf found at *key* in the container.
        return "_ary" + _ary_container_key_stringifier(key)

    # {{{ flatten the container into name -> leaf

    leaves = {}

    def _collect(key, leaf):
        leaves[_leaf_name(key)] = leaf

    rec_keyed_map_array_container(_collect, array)

    # }}}

    # {{{ split leaves into already-evaluated data and pytato expressions

    frozen_by_name = {}
    exprs_by_name = {}

    for name, leaf in leaves.items():
        if isinstance(leaf, jnp.ndarray):
            frozen_by_name[name] = leaf.block_until_ready()
        elif isinstance(leaf, pt.DataWrapper):
            # trivial freeze: the data already exists.
            frozen_by_name[name] = leaf.data.block_until_ready()
        elif isinstance(leaf, pt.Array):
            exprs_by_name[name] = leaf
        else:
            raise TypeError(
                f"{type(self).__name__}.freeze invoked with an unsupported "
                f"array type: got '{type(leaf).__name__}', but expected one "
                f"of {self.array_types}")

    # }}}

    # Code generation is only needed if at least one leaf is an unevaluated
    # pytato expression; otherwise every leaf is already a JAX array.
    if exprs_by_name:
        dag = self.transform_dag(pt.make_dict_of_named_arrays(exprs_by_name))
        computed = pt.generate_jax(dag, jit=True)()

        assert len(set(computed) & set(frozen_by_name)) == 0

        for name, value in computed.items():
            frozen_by_name[name] = value.block_until_ready()

    def _lookup_frozen(key, _leaf):
        return frozen_by_name[_leaf_name(key)]

    return with_array_context(
        rec_keyed_map_array_container(_lookup_frozen, array),
        actx=None)
def thaw(self, array):
    """
    Wrap every frozen leaf of *array* in a :mod:`pytato` data wrapper and
    associate the result with this array context.

    :arg array: an array or array container whose leaves are checked
        against ``self._frozen_array_types``.
    """
    import pytato as pt

    # NOTE: the helper must keep the name '_thaw'; _rec_map_container derives
    # part of its error message from the mapped function's __name__.
    def _thaw(ary):
        return pt.make_data_wrapper(ary)

    accepted_types = self._frozen_array_types
    thawed = self._rec_map_container(_thaw, array, accepted_types)
    return with_array_context(thawed, actx=self)
def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
    """
    Wrap *f* in a ``LazilyJAXCompilingFunctionCaller`` bound to this array
    context and return the wrapper.

    :arg f: the callable to wrap.
    """
    from .compile import LazilyJAXCompilingFunctionCaller

    lazily_compiled = LazilyJAXCompilingFunctionCaller(self, f)
    return lazily_compiled
def tag(self, tags: ToTagSetConvertible, array):
    """
    Return a copy of *array* in which every non-JAX array leaf has been
    tagged with *tags*.

    JAX arrays are returned untouched; all other leaves are tagged via
    their ``tagged`` method after *tags* has been normalized through
    ``_preprocess_array_tags``.

    :arg tags: the tag(s) to attach.
    :arg array: an array or array container.
    """
    # NOTE: the helper must keep the name '_tag'; _rec_map_container derives
    # part of its error message from the mapped function's __name__.
    def _tag(ary):
        import jax.numpy as jnp

        if not isinstance(ary, jnp.ndarray):
            return ary.tagged(_preprocess_array_tags(tags))
        return ary

    return self._rec_map_container(_tag, array)
def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
    """
    Return a copy of *array* in which axis *iaxis* of every non-JAX array
    leaf has been tagged with *tags*.

    JAX arrays are returned untouched; all other leaves are tagged via
    their ``with_tagged_axis`` method.

    :arg iaxis: index of the axis to tag.
    :arg tags: the tag(s) to attach to that axis.
    :arg array: an array or array container.
    """
    # NOTE: the helper must keep the name '_tag_axis'; _rec_map_container
    # derives part of its error message from the mapped function's __name__.
    def _tag_axis(ary):
        import jax.numpy as jnp

        if not isinstance(ary, jnp.ndarray):
            return ary.with_tagged_axis(iaxis, tags)
        return ary

    return self._rec_map_container(_tag_axis, array)
# }}}
# {{{ compilation
def call_loopy(self, program, **kwargs):
    """
    Unsupported for this array context.

    :raises NotImplementedError: always; the message suggests expressing
        the kernel through ``ArrayContext.np`` instead.
    """
    message = (
        "Calling loopy on JAX arrays is not supported. Maybe rewrite"
        " the loopy kernel as numpy-flavored array operations using"
        " ArrayContext.np."
    )
    raise NotImplementedError(message)
def einsum(self, spec, *args, arg_names=None, tagged=()): def einsum(self, spec, *args, arg_names=None, tagged=()):
import pyopencl.array as cla
import pytato as pt import pytato as pt
if arg_names is None: if arg_names is None:
arg_names = (None,) * len(args) arg_names = (None,) * len(args)
def preprocess_arg(name, arg): def preprocess_arg(name, arg):
if isinstance(arg, cla.Array): import jax.numpy as jnp
if isinstance(arg, jnp.ndarray):
ary = self.thaw(arg) ary = self.thaw(arg)
else: elif isinstance(arg, pt.Array):
assert isinstance(arg, pt.Array)
ary = arg ary = arg
else:
raise TypeError(
f"{type(self).__name__}.einsum invoked with an unsupported "
f"array type: got '{type(arg).__name__}', but expected one "
f"of {self.array_types}")
if name is not None: if name is not None: # noqa: SIM102
from pytato.tags import PrefixNamed
# Tagging Placeholders with naming-related tags is pointless: # Tagging Placeholders with naming-related tags is pointless:
# They already have names. It's also counterproductive, as # They already have names. It's also counterproductive, as
# multiple placeholders with the same name that are not # multiple placeholders with the same name that are not
# also the same object are not allowed, and this would produce # also the same object are not allowed, and this would produce
# a different Placeholder object of the same name. # a different Placeholder object of the same name.
if not isinstance(ary, pt.Placeholder): if (not isinstance(ary, pt.Placeholder)
ary = ary.tagged(PrefixNamed(name)) and not ary.tags_of_type(NameHint)):
ary = ary.tagged(NameHint(name))
return ary return ary
return pt.einsum(spec, *[ return pt.einsum(spec, *[
preprocess_arg(name, arg) preprocess_arg(name, arg)
for name, arg in zip(arg_names, args) for name, arg in zip(arg_names, args, strict=True)
]) ]).tagged(_preprocess_array_tags(tagged))
@property def clone(self):
def permits_inplace_modification(self): return type(self)()
return False
@property # }}}
def supports_nonscalar_broadcasting(self):
return True
@property # }}}
def permits_advanced_indexing(self):
return True # vim: foldmethod=marker
""" """
.. currentmodule:: arraycontext.impl.pytato.compile .. autoclass:: BaseLazilyCompilingFunctionCaller
.. autoclass:: LazilyCompilingFunctionCaller .. autoclass:: LazilyPyOpenCLCompilingFunctionCaller
.. autoclass:: LazilyJAXCompilingFunctionCaller
.. autoclass:: CompiledFunction .. autoclass:: CompiledFunction
.. autoclass:: FromArrayContextCompile .. autoclass:: FromArrayContextCompile
""" """
from __future__ import annotations
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
""" """
...@@ -28,27 +32,44 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -28,27 +32,44 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from arraycontext.container import ArrayContainer, is_array_container_type
from arraycontext import PytatoPyOpenCLArrayContext
from arraycontext.container.traversal import rec_keyed_map_array_container
import abc import abc
import numpy as np import itertools
from typing import Any, Callable, Tuple, Dict, Mapping import logging
from collections.abc import Callable, Hashable, Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pyrsistent import pmap, PMap from typing import Any
import numpy as np
from immutabledict import immutabledict
import pyopencl.array as cla
import pytato as pt import pytato as pt
import itertools from pytools import ProcessLogger, to_identifier
from pytools.tag import Tag from pytools.tag import Tag
from pytools import ProcessLogger from arraycontext.container import ArrayContainer, is_array_container_type
from arraycontext.container.traversal import rec_keyed_map_array_container
from arraycontext.context import ArrayT
from arraycontext.impl.pytato import (
PytatoJAXArrayContext,
PytatoPyOpenCLArrayContext,
_BasePytatoArrayContext,
)
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _prg_id_to_kernel_name(f: Any) -> str:
if callable(f):
name = getattr(f, "__name__", "anonymous")
if not name.isidentifier():
return "actx_compiled_" + to_identifier(name)
else:
return name
else:
return to_identifier(str(f))
class FromArrayContextCompile(Tag): class FromArrayContextCompile(Tag):
""" """
Tagged to the entrypoint kernel of every translation unit that is generated Tagged to the entrypoint kernel of every translation unit that is generated
...@@ -64,7 +85,7 @@ class FromArrayContextCompile(Tag): ...@@ -64,7 +85,7 @@ class FromArrayContextCompile(Tag):
class AbstractInputDescriptor: class AbstractInputDescriptor:
""" """
Used internally in :class:`LazilyCompilingFunctionCaller` to characterize Used internally in :class:`BaseLazilyCompilingFunctionCaller` to characterize
an input. an input.
""" """
def __eq__(self, other): def __eq__(self, other):
...@@ -87,9 +108,11 @@ class LeafArrayDescriptor(AbstractInputDescriptor): ...@@ -87,9 +108,11 @@ class LeafArrayDescriptor(AbstractInputDescriptor):
# }}} # }}}
def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: # {{{ utilities
def _ary_container_key_stringifier(keys: tuple[Any, ...]) -> str:
""" """
Helper for :meth:`LazilyCompilingFunctionCaller.__call__`. Stringifies an Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an
array-container's component's key. Goals of this routine: array-container's component's key. Goals of this routine:
* No two different keys should have the same stringification * No two different keys should have the same stringification
...@@ -97,7 +120,7 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: ...@@ -97,7 +120,7 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
* (informal) Shorter identifiers are preferred * (informal) Shorter identifiers are preferred
""" """
def _rec_str(key: Any) -> str: def _rec_str(key: Any) -> str:
if isinstance(key, (str, int)): if isinstance(key, str | int):
return str(key) return str(key)
elif isinstance(key, tuple): elif isinstance(key, tuple):
# t in '_actx_t': stands for tuple # t in '_actx_t': stands for tuple
...@@ -109,22 +132,20 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: ...@@ -109,22 +132,20 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
return "_".join(_rec_str(key) for key in keys) return "_".join(_rec_str(key) for key in keys)
def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], def _get_arg_id_to_arg_and_arg_id_to_descr(args: tuple[Any, ...],
kwargs: Mapping[str, Any] kwargs: Mapping[str, Any]
) -> "Tuple[PMap[Tuple[Any, ...],\ ) -> \
Any],\ tuple[Mapping[tuple[Hashable, ...], Any],
PMap[Tuple[Any, ...],\ Mapping[tuple[Hashable, ...], AbstractInputDescriptor]]:
AbstractInputDescriptor]\
]":
""" """
Helper for :meth:`LazilyCompilingFunctionCaller.__call__`. Extracts Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Extracts
mappings from argument id to argument values and from argument id to mappings from argument id to argument values and from argument id to
:class:`AbstractInputDescriptor`. See :class:`AbstractInputDescriptor`. See
:attr:`CompiledFunction.input_id_to_name_in_program` for argument-id's :attr:`CompiledFunction.input_id_to_name_in_program` for argument-id's
representation. representation.
""" """
arg_id_to_arg: Dict[Tuple[Any, ...], Any] = {} arg_id_to_arg: dict[tuple[Hashable, ...], Any] = {}
arg_id_to_descr: Dict[Tuple[Any, ...], AbstractInputDescriptor] = {} arg_id_to_descr: dict[tuple[Hashable, ...], AbstractInputDescriptor] = {}
for kw, arg in itertools.chain(enumerate(args), for kw, arg in itertools.chain(enumerate(args),
kwargs.items()): kwargs.items()):
...@@ -134,10 +155,10 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], ...@@ -134,10 +155,10 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...],
arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(type(arg))) arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(type(arg)))
elif is_array_container_type(arg.__class__): elif is_array_container_type(arg.__class__):
def id_collector(keys, ary): def id_collector(keys, ary):
arg_id = (kw,) + keys arg_id = (kw, *keys) # noqa: B023
arg_id_to_arg[arg_id] = ary arg_id_to_arg[arg_id] = ary
arg_id_to_descr[arg_id] = LeafArrayDescriptor(np.dtype(ary.dtype), arg_id_to_descr[arg_id] = LeafArrayDescriptor(
ary.shape) np.dtype(ary.dtype), ary.shape)
return ary return ary
rec_keyed_map_array_container(id_collector, arg) rec_keyed_map_array_container(id_collector, arg)
...@@ -151,38 +172,89 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], ...@@ -151,38 +172,89 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...],
" either a scalar, pt.Array or an array container. Got" " either a scalar, pt.Array or an array container. Got"
f" '{arg}'.") f" '{arg}'.")
return pmap(arg_id_to_arg), pmap(arg_id_to_descr) return immutabledict(arg_id_to_arg), immutabledict(arg_id_to_descr)
def _get_f_placeholder_args(arg, kw, arg_id_to_name): def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext):
""" """
Helper for :class:`LazilyCompilingFunctionCaller.__call__`. Returns the Preprocess *ary* before turning it into a :class:`pytato.array.Placeholder`
in :meth:`LazilyCompilingFunctionCaller.__call__`.
Preprocessing here refers to:
- Metadata Inference that is supplied via *actx*\'s
:meth:`PytatoPyOpenCLArrayContext.transform_dag`.
"""
import pyopencl.array as cla
from arraycontext.impl.pyopencl.taggable_cl_array import (
TaggableCLArray,
to_tagged_cl_array,
)
if isinstance(ary, pt.Array):
dag = pt.make_dict_of_named_arrays({"_actx_out": ary})
# Transform the DAG to give metadata inference a chance to do its job
return actx.transform_dag(dag)["_actx_out"].expr
elif isinstance(ary, TaggableCLArray):
return ary
elif isinstance(ary, cla.Array):
from warnings import warn
warn("Passing pyopencl.array.Array to a compiled callable"
" is deprecated and will stop working in 2023."
" Use `to_tagged_cl_array` to convert the array to"
" TaggableCLArray", DeprecationWarning, stacklevel=2)
return to_tagged_cl_array(ary,
axes=None,
tags=frozenset())
else:
raise NotImplementedError(type(ary))
def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx):
"""
Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. Returns the
placeholder version of an argument to placeholder version of an argument to
:attr:`LazilyCompilingFunctionCaller.f`. :attr:`BaseLazilyCompilingFunctionCaller.f`.
""" """
if np.isscalar(arg): if np.isscalar(arg):
name = arg_id_to_name[(kw,)] from pytato.tags import ForceValueArgTag
return pt.make_placeholder(name, (), np.dtype(type(arg))) name = arg_id_to_name[kw,]
return pt.make_placeholder(name, (), np.dtype(type(arg)),
tags=frozenset({ForceValueArgTag()}))
elif isinstance(arg, pt.Array): elif isinstance(arg, pt.Array):
name = arg_id_to_name[(kw,)] name = arg_id_to_name[kw,]
return pt.make_placeholder(name, arg.shape, arg.dtype) # Transform the DAG to give metadata inference a chance to do its job
arg = _to_input_for_compiled(arg, actx)
return pt.make_placeholder(name, arg.shape, arg.dtype,
axes=arg.axes,
tags=arg.tags)
elif is_array_container_type(arg.__class__): elif is_array_container_type(arg.__class__):
def _rec_to_placeholder(keys, ary): def _rec_to_placeholder(keys, ary):
name = arg_id_to_name[(kw,) + keys] index = (kw, *keys)
return pt.make_placeholder(name, ary.shape, ary.dtype) name = arg_id_to_name[index]
# Transform the DAG to give metadata inference a chance to do its job
ary = _to_input_for_compiled(ary, actx)
return pt.make_placeholder(name,
ary.shape,
ary.dtype,
axes=ary.axes,
tags=ary.tags)
return rec_keyed_map_array_container(_rec_to_placeholder, arg) return rec_keyed_map_array_container(_rec_to_placeholder, arg)
else: else:
raise NotImplementedError(type(arg)) raise NotImplementedError(type(arg))
# }}}
# {{{ BaseLazilyCompilingFunctionCaller
@dataclass @dataclass
class LazilyCompilingFunctionCaller: class BaseLazilyCompilingFunctionCaller:
""" """
Records a side-effect-free callable Records a side-effect-free callable :attr:`f` that can be specialized for
:attr:`LazilyCompilingFunctionCaller.f` that can be specialized for the the input types with which :meth:`__call__` is invoked.
input types with which :meth:`LazilyCompilingFunctionCaller.__call__` is
invoked.
.. attribute:: f .. attribute:: f
...@@ -191,41 +263,26 @@ class LazilyCompilingFunctionCaller: ...@@ -191,41 +263,26 @@ class LazilyCompilingFunctionCaller:
.. automethod:: __call__ .. automethod:: __call__
""" """
actx: PytatoPyOpenCLArrayContext actx: _BasePytatoArrayContext
f: Callable[..., Any] f: Callable[..., Any]
program_cache: Dict["PMap[Tuple[Any, ...], AbstractInputDescriptor]", program_cache: dict[Mapping[tuple[Hashable, ...], AbstractInputDescriptor],
"CompiledFunction"] = field(default_factory=lambda: {}) CompiledFunction] = field(default_factory=lambda: {})
def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays): # {{{ abstract interface
from pytato.target.loopy import BoundPyOpenCLProgram
import loopy as lp
with ProcessLogger(logger, "transform_dag"):
pt_dict_of_named_arrays = self.actx.transform_dag(dict_of_named_arrays)
with ProcessLogger(logger, "generate_loopy"):
pytato_program = pt.generate_loopy(pt_dict_of_named_arrays,
options=lp.Options(
return_dict=True,
no_numpy=True),
cl_device=self.actx.queue.device)
assert isinstance(pytato_program, BoundPyOpenCLProgram)
with ProcessLogger(logger, "transform_loopy_program"): def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
raise NotImplementedError
pytato_program = (pytato_program @property
.with_transformed_program( def compiled_function_returning_array_container_class(
lambda x: x.with_kernel( self) -> type[CompiledFunction]:
x.default_entrypoint raise NotImplementedError
.tagged(FromArrayContextCompile()))))
pytato_program = (pytato_program @property
.with_transformed_program(self def compiled_function_returning_array_class(self) -> type[CompiledFunction]:
.actx raise NotImplementedError
.transform_loopy_program))
return pytato_program # }}}
def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays, def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays,
input_id_to_name_in_program, output_id_to_name_in_program, input_id_to_name_in_program, output_id_to_name_in_program,
...@@ -234,30 +291,37 @@ class LazilyCompilingFunctionCaller: ...@@ -234,30 +291,37 @@ class LazilyCompilingFunctionCaller:
output_id = "_pt_out" output_id = "_pt_out"
dict_of_named_arrays = pt.make_dict_of_named_arrays( dict_of_named_arrays = pt.make_dict_of_named_arrays(
{output_id: ary_or_dict_of_named_arrays}) {output_id: ary_or_dict_of_named_arrays})
pytato_program = self._dag_to_transformed_loopy_prg(dict_of_named_arrays) pytato_program, name_in_program_to_tags, name_in_program_to_axes = (
return CompiledFunctionReturningArray( self._dag_to_transformed_pytato_prg(dict_of_named_arrays,
prg_id=self.f))
return self.compiled_function_returning_array_class(
self.actx, pytato_program, self.actx, pytato_program,
input_id_to_name_in_program=input_id_to_name_in_program, input_id_to_name_in_program=input_id_to_name_in_program,
output_name_in_program=output_id) output_tags=name_in_program_to_tags[output_id],
output_axes=name_in_program_to_axes[output_id],
output_name=output_id)
elif isinstance(ary_or_dict_of_named_arrays, pt.DictOfNamedArrays): elif isinstance(ary_or_dict_of_named_arrays, pt.DictOfNamedArrays):
pytato_program = self._dag_to_transformed_loopy_prg( pytato_program, name_in_program_to_tags, name_in_program_to_axes = (
ary_or_dict_of_named_arrays) self._dag_to_transformed_pytato_prg(ary_or_dict_of_named_arrays,
return CompiledFunctionReturningArrayContainer( prg_id=self.f))
return self.compiled_function_returning_array_container_class(
self.actx, pytato_program, self.actx, pytato_program,
input_id_to_name_in_program=input_id_to_name_in_program, input_id_to_name_in_program=input_id_to_name_in_program,
output_id_to_name_in_program=output_id_to_name_in_program, output_id_to_name_in_program=output_id_to_name_in_program,
name_in_program_to_tags=name_in_program_to_tags,
name_in_program_to_axes=name_in_program_to_axes,
output_template=output_template) output_template=output_template)
else: else:
raise NotImplementedError(type(ary_or_dict_of_named_arrays)) raise NotImplementedError(type(ary_or_dict_of_named_arrays))
def __call__(self, *args: Any, **kwargs: Any) -> Any: def __call__(self, *args: Any, **kwargs: Any) -> Any:
""" """
Returns the result of :attr:`~LazilyCompilingFunctionCaller.f`'s Returns the result of :attr:`~BaseLazilyCompilingFunctionCaller.f`'s
function application on *args*. function application on *args*.
Before applying :attr:`~LazilyCompilingFunctionCaller.f`, it is compiled Before applying :attr:`~BaseLazilyCompilingFunctionCaller.f`, it is compiled
to a :mod:`pytato` DAG that would apply to a :mod:`pytato` DAG that would apply
:attr:`~LazilyCompilingFunctionCaller.f` with *args* in a lazy-sense. :attr:`~BaseLazilyCompilingFunctionCaller.f` with *args* in a lazy-sense.
The intermediary pytato DAG for *args* is memoized in *self*. The intermediary pytato DAG for *args* is memoized in *self*.
""" """
arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(
...@@ -277,11 +341,16 @@ class LazilyCompilingFunctionCaller: ...@@ -277,11 +341,16 @@ class LazilyCompilingFunctionCaller:
for arg_id in arg_id_to_arg} for arg_id in arg_id_to_arg}
output_template = self.f( output_template = self.f(
*[_get_f_placeholder_args(arg, iarg, input_id_to_name_in_program) *[_get_f_placeholder_args(arg, iarg,
input_id_to_name_in_program, self.actx)
for iarg, arg in enumerate(args)], for iarg, arg in enumerate(args)],
**{kw: _get_f_placeholder_args(arg, kw, input_id_to_name_in_program) **{kw: _get_f_placeholder_args(arg, kw,
input_id_to_name_in_program,
self.actx)
for kw, arg in kwargs.items()}) for kw, arg in kwargs.items()})
self.actx._compile_trace_callback(self.f, "post_trace", output_template)
if (not (is_array_container_type(output_template.__class__) if (not (is_array_container_type(output_template.__class__)
or isinstance(output_template, pt.Array))): or isinstance(output_template, pt.Array))):
# TODO: We could possibly just short-circuit this interface if the # TODO: We could possibly just short-circuit this interface if the
...@@ -292,8 +361,7 @@ class LazilyCompilingFunctionCaller: ...@@ -292,8 +361,7 @@ class LazilyCompilingFunctionCaller:
f" but an instance of '{output_template.__class__}' instead.") f" but an instance of '{output_template.__class__}' instead.")
def _as_dict_of_named_arrays(keys, ary): def _as_dict_of_named_arrays(keys, ary):
name = "_pt_out_" + "_".join(str(key) name = "_pt_out_" + _ary_container_key_stringifier(keys)
for key in keys)
output_id_to_name_in_program[keys] = name output_id_to_name_in_program[keys] = name
dict_of_named_arrays[name] = ary dict_of_named_arrays[name] = ary
return ary return ary
...@@ -310,22 +378,188 @@ class LazilyCompilingFunctionCaller: ...@@ -310,22 +378,188 @@ class LazilyCompilingFunctionCaller:
self.program_cache[arg_id_to_descr] = compiled_func self.program_cache[arg_id_to_descr] = compiled_func
return compiled_func(arg_id_to_arg) return compiled_func(arg_id_to_arg)
# }}}
def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): # {{{ LazilyPyOpenCLCompilingFunctionCaller
class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller):
    """Function caller that turns a traced :mod:`pytato` DAG into a
    :mod:`loopy`-based program bound to the array context's PyOpenCL context.
    """
    actx: PytatoPyOpenCLArrayContext

    @property
    def compiled_function_returning_array_container_class(
            self) -> type[CompiledFunction]:
        """Compiled-function wrapper used for array-container results."""
        return CompiledPyOpenCLFunctionReturningArrayContainer

    @property
    def compiled_function_returning_array_class(self) -> type[CompiledFunction]:
        """Compiled-function wrapper used for single-array results."""
        return CompiledPyOpenCLFunctionReturningArray

    def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
        """Transform *dict_of_named_arrays*, generate a loopy program from it
        and transform that program.

        The array context's ``_compile_trace_callback`` is invoked before and
        after every stage. *prg_id* defaults to ``self.f``.

        :returns: a tuple ``(pytato_program, name_in_program_to_tags,
            name_in_program_to_axes)``; the two mappings are keyed by output
            name and are taken from the transformed DAG.
        """
        if prg_id is None:
            prg_id = self.f

        from pytato.target.loopy import BoundPyOpenCLExecutable

        actx = self.actx

        # {{{ DAG-level transformation

        actx._compile_trace_callback(
            prg_id, "pre_transform_dag", dict_of_named_arrays)

        with ProcessLogger(logger, f"transform_dag for '{prg_id}'"):
            transformed_dag = actx.transform_dag(dict_of_named_arrays)

        actx._compile_trace_callback(
            prg_id, "post_transform_dag", transformed_dag)

        # }}}

        # Record the metadata of each named output of the transformed DAG.
        tags_by_name = {}
        axes_by_name = {}
        for out_name, out_ary in transformed_dag._data.items():
            tags_by_name[out_name] = out_ary.tags
            axes_by_name[out_name] = out_ary.axes

        # {{{ code generation

        actx._compile_trace_callback(
            prg_id, "pre_generate_loopy", transformed_dag)

        with ProcessLogger(logger, f"generate_loopy for '{prg_id}'"):
            from arraycontext.loopy import _DEFAULT_LOOPY_OPTIONS
            opts = _DEFAULT_LOOPY_OPTIONS
            assert opts.return_dict

            unbound_program = pt.generate_loopy(
                transformed_dag,
                options=opts,
                function_name=_prg_id_to_kernel_name(prg_id),
                target=actx.get_target(),
                )
            # pylint: disable=no-member
            executable = unbound_program.bind_to_context(actx.context)
            assert isinstance(executable, BoundPyOpenCLExecutable)

        actx._compile_trace_callback(
            prg_id, "post_generate_loopy", executable)

        # }}}

        # {{{ loopy-level transformation

        actx._compile_trace_callback(
            prg_id, "pre_transform_loopy_program", executable)

        def _tag_entrypoint(t_unit):
            # Mark the entrypoint kernel as originating from actx.compile.
            return t_unit.with_kernel(
                t_unit.default_entrypoint.tagged(FromArrayContextCompile()))

        with ProcessLogger(logger, f"transform_loopy_program for '{prg_id}'"):
            executable = executable.with_transformed_translation_unit(
                _tag_entrypoint)
            executable = executable.with_transformed_translation_unit(
                actx.transform_loopy_program)

        actx._compile_trace_callback(
            prg_id, "post_transform_loopy_program", executable)

        # }}}

        actx._compile_trace_callback(prg_id, "final", executable)

        return executable, tags_by_name, axes_by_name
# }}}
# {{{ preserve back compat
class LazilyCompilingFunctionCaller(LazilyPyOpenCLCompilingFunctionCaller):
    """Deprecated name for :class:`LazilyPyOpenCLCompilingFunctionCaller`.

    Creating an instance and calling the old-named
    :meth:`_dag_to_transformed_loopy_prg` each emit a
    :class:`DeprecationWarning`.
    """

    def __new__(cls, *args, **kwargs):
        import warnings

        message = ("LazilyCompilingFunctionCaller has been renamed to"
                   " LazilyPyOpenCLCompilingFunctionCaller. This will be"
                   " an error in 2023.")
        warnings.warn(message, DeprecationWarning, stacklevel=2)

        # Constructor arguments are not forwarded to the parent's __new__.
        return super().__new__(cls)

    def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays):
        """Warn, then forward to ``_dag_to_transformed_pytato_prg``."""
        import warnings

        message = ("_dag_to_transformed_loopy_prg has been renamed to"
                   " _dag_to_transformed_pytato_prg. This will be"
                   " an error in 2023.")
        warnings.warn(message, DeprecationWarning, stacklevel=2)

        return super()._dag_to_transformed_pytato_prg(dict_of_named_arrays)
# }}}
# {{{ LazilyJAXCompilingFunctionCaller
class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller):
    """Function caller that turns a traced :mod:`pytato` DAG into a
    JAX program via :func:`pytato.generate_jax`.
    """

    @property
    def compiled_function_returning_array_container_class(
            self) -> type[CompiledFunction]:
        """Compiled-function wrapper used for array-container results."""
        return CompiledJAXFunctionReturningArrayContainer

    @property
    def compiled_function_returning_array_class(self) -> type[CompiledFunction]:
        """Compiled-function wrapper used for single-array results."""
        return CompiledJAXFunctionReturningArray

    def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
        """Transform *dict_of_named_arrays* and generate a jitted JAX program.

        The array context's ``_compile_trace_callback`` is invoked before and
        after every stage. *prg_id* defaults to ``self.f``.

        :returns: a tuple ``(pytato_program, name_in_program_to_tags,
            name_in_program_to_axes)``; the two mappings are keyed by output
            name and are taken from the transformed DAG.
        """
        if prg_id is None:
            prg_id = self.f

        actx = self.actx

        # {{{ DAG-level transformation

        actx._compile_trace_callback(
            prg_id, "pre_transform_dag", dict_of_named_arrays)

        with ProcessLogger(logger, f"transform_dag for '{prg_id}'"):
            transformed_dag = actx.transform_dag(dict_of_named_arrays)

        actx._compile_trace_callback(
            prg_id, "post_transform_dag", transformed_dag)

        # }}}

        # Record the metadata of each named output of the transformed DAG.
        tags_by_name = {}
        axes_by_name = {}
        for out_name, out_ary in transformed_dag._data.items():
            tags_by_name[out_name] = out_ary.tags
            axes_by_name[out_name] = out_ary.axes

        # {{{ code generation

        actx._compile_trace_callback(
            prg_id, "pre_generate_jax", transformed_dag)

        with ProcessLogger(logger, f"generate_jax for '{prg_id}'"):
            jax_program = pt.generate_jax(
                transformed_dag,
                jit=True,
                function_name=_prg_id_to_kernel_name(prg_id))

        actx._compile_trace_callback(
            prg_id, "post_generate_jax", jax_program)

        # }}}

        return jax_program, tags_by_name, axes_by_name
def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
input_kwargs_for_loopy = {} input_kwargs_for_loopy = {}
for arg_id, arg in arg_id_to_arg.items(): for arg_id, arg in arg_id_to_arg.items():
if np.isscalar(arg): if np.isscalar(arg):
arg = cla.to_device(actx.queue, np.array(arg)) if isinstance(actx, PytatoPyOpenCLArrayContext):
# Scalar kernel args are passed as lp.ValueArgs
pass
elif isinstance(actx, PytatoJAXArrayContext):
import jax
arg = jax.device_put(arg)
else:
raise NotImplementedError(type(actx))
elif isinstance(arg, pt.array.DataWrapper): elif isinstance(arg, pt.array.DataWrapper):
# got a Datwwrapper => simply gets its data # got a Datawrapper => simply gets its data
arg = arg.data arg = arg.data
elif isinstance(arg, cla.Array): elif isinstance(arg, actx._frozen_array_types):
# got a frozen array => do nothing # got a frozen array => do nothing
pass pass
elif isinstance(arg, pt.Array): elif isinstance(arg, pt.Array):
# got an array expression => evaluate it # got an array expression => evaluate it
arg = actx.freeze(arg).with_queue(actx.queue) from warnings import warn
warn(f"Argument array '{arg_id}' to a compiled function is "
"unevaluated. Evaluating just-in-time, at "
"considerable expense. This is deprecated and will stop "
"working in 2023. To avoid this warning, force evaluation "
"of all arguments via freeze/thaw.",
DeprecationWarning, stacklevel=4)
arg = actx.freeze(arg)
else: else:
raise NotImplementedError(type(arg)) raise NotImplementedError(type(arg))
...@@ -334,10 +568,23 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): ...@@ -334,10 +568,23 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
return input_kwargs_for_loopy return input_kwargs_for_loopy
def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
    """Deprecated alias of :func:`_args_to_device_buffers`.

    Emits a :class:`DeprecationWarning` and forwards all three arguments
    unchanged, returning whatever :func:`_args_to_device_buffers` returns.
    """
    from warnings import warn

    # The message must name this function exactly, so that users can find the
    # deprecated call by searching for it.
    warn("_args_to_cl_buffers has been renamed to"
         " _args_to_device_buffers. This will be"
         " an error in 2023.", DeprecationWarning, stacklevel=2)

    return _args_to_device_buffers(actx, input_id_to_name_in_program,
                                   arg_id_to_arg)
# }}}
# {{{ compiled function
class CompiledFunction(abc.ABC): class CompiledFunction(abc.ABC):
""" """
A callable which captures the :class:`pytato.target.BoundProgram` resulting A callable which captures the :class:`pytato.target.BoundProgram` resulting
from calling :attr:`~LazilyCompilingFunctionCaller.f` with a given set of from calling :attr:`~BaseLazilyCompilingFunctionCaller.f` with a given set of
input types, and generating :mod:`loopy` IR from it. input types, and generating :mod:`loopy` IR from it.
.. attribute:: pytato_program .. attribute:: pytato_program
...@@ -346,7 +593,7 @@ class CompiledFunction(abc.ABC): ...@@ -346,7 +593,7 @@ class CompiledFunction(abc.ABC):
A mapping from input id to the placeholder name in A mapping from input id to the placeholder name in
:attr:`CompiledFunction.pytato_program`. Input id is represented as the :attr:`CompiledFunction.pytato_program`. Input id is represented as the
position of :attr:`~LazilyCompilingFunctionCaller.f`'s argument augmented position of :attr:`~BaseLazilyCompilingFunctionCaller.f`'s argument augmented
with the leaf array's key if the argument is an array container. with the leaf array's key if the argument is an array container.
...@@ -362,9 +609,13 @@ class CompiledFunction(abc.ABC): ...@@ -362,9 +609,13 @@ class CompiledFunction(abc.ABC):
""" """
pass pass
# }}}
# {{{ compiled pyopencl function
@dataclass(frozen=True) @dataclass(frozen=True)
class CompiledFunctionReturningArrayContainer(CompiledFunction): class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction):
""" """
.. attribute:: output_id_to_name_in_program .. attribute:: output_id_to_name_in_program
...@@ -381,12 +632,17 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction): ...@@ -381,12 +632,17 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction):
""" """
actx: PytatoPyOpenCLArrayContext actx: PytatoPyOpenCLArrayContext
pytato_program: pt.target.BoundProgram pytato_program: pt.target.BoundProgram
input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
output_id_to_name_in_program: Mapping[Tuple[Any, ...], str] output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
name_in_program_to_tags: Mapping[str, frozenset[Tag]]
name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]]
output_template: ArrayContainer output_template: ArrayContainer
def __call__(self, arg_id_to_arg) -> ArrayContainer: def __call__(self, arg_id_to_arg) -> ArrayContainer:
input_kwargs_for_loopy = _args_to_cl_buffers( from .utils import get_cl_axes_from_pt_axes
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
input_kwargs_for_loopy = _args_to_device_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg) self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
evt, out_dict = self.pytato_program(queue=self.actx.queue, evt, out_dict = self.pytato_program(queue=self.actx.queue,
...@@ -399,14 +655,19 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction): ...@@ -399,14 +655,19 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction):
evt.wait() evt.wait()
def to_output_template(keys, _): def to_output_template(keys, _):
return self.actx.thaw(out_dict[self.output_id_to_name_in_program[keys]]) name_in_program = self.output_id_to_name_in_program[keys]
return self.actx.thaw(to_tagged_cl_array(
out_dict[name_in_program],
axes=get_cl_axes_from_pt_axes(
self.name_in_program_to_axes[name_in_program]),
tags=self.name_in_program_to_tags[name_in_program]))
return rec_keyed_map_array_container(to_output_template, return rec_keyed_map_array_container(to_output_template,
self.output_template) self.output_template)
@dataclass(frozen=True) @dataclass(frozen=True)
class CompiledFunctionReturningArray(CompiledFunction): class CompiledPyOpenCLFunctionReturningArray(CompiledFunction):
""" """
.. attribute:: output_name_in_program .. attribute:: output_name_in_program
...@@ -414,11 +675,16 @@ class CompiledFunctionReturningArray(CompiledFunction): ...@@ -414,11 +675,16 @@ class CompiledFunctionReturningArray(CompiledFunction):
""" """
actx: PytatoPyOpenCLArrayContext actx: PytatoPyOpenCLArrayContext
pytato_program: pt.target.BoundProgram pytato_program: pt.target.BoundProgram
input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
output_tags: frozenset[Tag]
output_axes: tuple[pt.Axis, ...]
output_name: str output_name: str
def __call__(self, arg_id_to_arg) -> ArrayContainer: def __call__(self, arg_id_to_arg) -> ArrayContainer:
input_kwargs_for_loopy = _args_to_cl_buffers( from .utils import get_cl_axes_from_pt_axes
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
input_kwargs_for_loopy = _args_to_device_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg) self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
evt, out_dict = self.pytato_program(queue=self.actx.queue, evt, out_dict = self.pytato_program(queue=self.actx.queue,
...@@ -430,4 +696,78 @@ class CompiledFunctionReturningArray(CompiledFunction): ...@@ -430,4 +696,78 @@ class CompiledFunctionReturningArray(CompiledFunction):
# running out of memory. This mitigates that risk a bit, for now. # running out of memory. This mitigates that risk a bit, for now.
evt.wait() evt.wait()
return self.actx.thaw(to_tagged_cl_array(out_dict[self.output_name],
axes=get_cl_axes_from_pt_axes(
self.output_axes),
tags=self.output_tags))
# }}}
# {{{ compiled jax function
@dataclass(frozen=True)
class CompiledJAXFunctionReturningArrayContainer(CompiledFunction):
    """
    A compiled JAX function whose result is an array container.

    .. attribute:: output_id_to_name_in_program

        Maps each output id to the name of the corresponding
        :class:`pytato.array.NamedArray` in
        :attr:`CompiledFunction.pytato_program`. An output id is the key of a
        leaf array in :attr:`CompiledFunction.output_template`.

    .. attribute:: output_template

        The :class:`arraycontext.ArrayContainer` instance whose structure the
        returned container follows.
    """
    actx: PytatoJAXArrayContext
    pytato_program: pt.target.BoundProgram
    input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
    output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
    name_in_program_to_tags: Mapping[str, frozenset[Tag]]
    name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]]
    output_template: ArrayContainer

    def __call__(self, arg_id_to_arg) -> ArrayContainer:
        """Run the program on *arg_id_to_arg* and rebuild the output
        container from the named results.
        """
        program_kwargs = _args_to_device_buffers(
            self.actx, self.input_id_to_name_in_program, arg_id_to_arg)

        results_by_name = self.pytato_program(**program_kwargs)

        def _thawed_leaf(keys, _):
            # Look up the result for this leaf and wait for it before thawing.
            name_in_program = self.output_id_to_name_in_program[keys]
            ready = results_by_name[name_in_program].block_until_ready()
            return self.actx.thaw(ready)

        return rec_keyed_map_array_container(_thawed_leaf,
                                             self.output_template)
@dataclass(frozen=True)
class CompiledJAXFunctionReturningArray(CompiledFunction):
    """
    A compiled JAX function whose result is a single array.

    .. attribute:: output_name

        Name of the output array in the program.
    """
    actx: PytatoJAXArrayContext
    pytato_program: pt.target.BoundProgram
    input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
    output_tags: frozenset[Tag]
    output_axes: tuple[pt.Axis, ...]
    output_name: str

    def __call__(self, arg_id_to_arg) -> ArrayContainer:
        """Run the program on *arg_id_to_arg* and return the thawed array
        stored under :attr:`output_name`.
        """
        input_kwargs_for_loopy = _args_to_device_buffers(
            self.actx, self.input_id_to_name_in_program, arg_id_to_arg)

        # CompiledJAXFunctionReturningArrayContainer indexes the result of
        # calling its JAX program directly by output name, i.e. the call
        # yields a name -> array mapping and no event. Unpacking it as
        # ``(evt, out_dict)`` (the PyOpenCL convention) would fail for a
        # single-entry mapping, so the result is used as-is here too.
        out_dict = self.pytato_program(**input_kwargs_for_loopy)

        return self.actx.thaw(out_dict[self.output_name])
# }}}
# vim: foldmethod=marker
from __future__ import annotations
__copyright__ = """ __copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees Copyright (C) 2021 University of Illinois Board of Trustees
""" """
...@@ -22,19 +25,21 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -22,19 +25,21 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from functools import partial, reduce from functools import partial, reduce
from typing import Any, cast
import numpy as np import numpy as np
from arraycontext.fake_numpy import ( import pytato as pt
BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace,
)
from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container import NotAnArrayContainerError, serialize_container
from arraycontext.container.traversal import ( from arraycontext.container.traversal import (
rec_map_array_container, rec_map_array_container,
rec_multimap_array_container, rec_map_reduce_array_container,
rec_map_reduce_array_container, rec_multimap_array_container,
) )
import pytato as pt from arraycontext.context import Array, ArrayOrContainer
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
...@@ -42,7 +47,7 @@ class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): ...@@ -42,7 +47,7 @@ class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
pass pass
class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
""" """
A :mod:`numpy` mimic for :class:`PytatoPyOpenCLArrayContext`. A :mod:`numpy` mimic for :class:`PytatoPyOpenCLArrayContext`.
...@@ -51,96 +56,74 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): ...@@ -51,96 +56,74 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
:mod:`pytato` does not define any memory layout for its arrays. See :mod:`pytato` does not define any memory layout for its arrays. See
:ref:`Pytato docs <pytato:memory-layout>` for more on this. :ref:`Pytato docs <pytato:memory-layout>` for more on this.
""" """
_pt_unary_funcs = frozenset({
"sin", "cos", "tan", "arcsin", "arccos", "arctan",
"sinh", "cosh", "tanh", "exp", "log", "log10",
"sqrt", "abs", "isnan", "real", "imag", "conj",
"logical_not",
})
_pt_multi_ary_funcs = frozenset({
"arctan2", "equal", "greater", "greater_equal", "less", "less_equal",
"not_equal", "minimum", "maximum", "where", "logical_and", "logical_or",
})
def _get_fake_numpy_linalg_namespace(self): def _get_fake_numpy_linalg_namespace(self):
return PytatoFakeNumpyLinalgNamespace(self._array_context) return PytatoFakeNumpyLinalgNamespace(self._array_context)
def __getattr__(self, name): def __getattr__(self, name):
if name in self._pt_unary_funcs:
pt_funcs = ["abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan",
"sinh", "cosh", "tanh", "exp", "log", "log10", "isnan",
"sqrt", "exp"]
if name in pt_funcs:
from functools import partial from functools import partial
return partial(rec_map_array_container, getattr(pt, name)) return partial(rec_map_array_container, getattr(pt, name))
return super().__getattr__(name) if name in self._pt_multi_ary_funcs:
from functools import partial
def reshape(self, a, newshape, order="C"): return partial(rec_multimap_array_container, getattr(pt, name))
return rec_map_array_container(
lambda ary: pt.reshape(a, newshape, order=order),
a)
def transpose(self, a, axes=None):
return rec_multimap_array_container(pt.transpose, a, axes)
def concatenate(self, arrays, axis=0):
return rec_multimap_array_container(pt.concatenate, arrays, axis)
def ones_like(self, ary):
def _ones_like(subary):
return pt.ones(subary.shape, subary.dtype)
return self._new_like(ary, _ones_like)
def maximum(self, x, y):
return rec_multimap_array_container(pt.maximum, x, y)
def minimum(self, x, y):
return rec_multimap_array_container(pt.minimum, x, y)
def where(self, criterion, then, else_): return super().__getattr__(name)
return rec_multimap_array_container(pt.where, criterion, then, else_)
def sum(self, a, axis=None, dtype=None): # NOTE: the order of these follows the order in numpy docs
def _pt_sum(ary): # NOTE: when adding a function here, also add it to `array_context.rst` docs!
if dtype not in [ary.dtype, None]:
raise NotImplementedError
return pt.sum(ary, axis=axis) # {{{ array creation routines
return rec_map_reduce_array_container(sum, _pt_sum, a) def zeros(self, shape, dtype):
return pt.zeros(shape, dtype)
def min(self, a, axis=None): def zeros_like(self, ary):
return rec_map_reduce_array_container( def _zeros_like(array):
partial(reduce, pt.minimum), partial(pt.amin, axis=axis), a) return self._array_context.zeros(
array.shape, array.dtype).copy(axes=array.axes, tags=array.tags)
def max(self, a, axis=None): return self._array_context._rec_map_container(
return rec_map_reduce_array_container( _zeros_like, ary, default_scalar=0)
partial(reduce, pt.maximum), partial(pt.amax, axis=axis), a)
def stack(self, arrays, axis=0): def ones_like(self, ary):
return rec_multimap_array_container( return self.full_like(ary, 1)
lambda *args: pt.stack(arrays=args, axis=axis),
*arrays)
def broadcast_to(self, array, shape):
return rec_map_array_container(partial(pt.broadcast_to, shape=shape), array)
# {{{ relational operators
def equal(self, x, y):
return rec_multimap_array_container(pt.equal, x, y)
def not_equal(self, x, y): def full_like(self, ary, fill_value):
return rec_multimap_array_container(pt.not_equal, x, y) def _full_like(subary):
return pt.full(subary.shape, fill_value, subary.dtype).copy(
axes=subary.axes, tags=subary.tags)
def greater(self, x, y): return self._array_context._rec_map_container(
return rec_multimap_array_container(pt.greater, x, y) _full_like, ary, default_scalar=fill_value)
def greater_equal(self, x, y): def arange(self, *args: Any, **kwargs: Any):
return rec_multimap_array_container(pt.greater_equal, x, y) return pt.arange(*args, **kwargs)
def less(self, x, y): def full(self, shape, fill_value, dtype=None):
return rec_multimap_array_container(pt.less, x, y) return pt.full(shape, fill_value, dtype)
def less_equal(self, x, y): # }}}
return rec_multimap_array_container(pt.less_equal, x, y)
def conj(self, x): # {{{ array manipulation routines
return rec_multimap_array_container(pt.conj, x)
def reshape(self, a, newshape, order="C"):
    """Reshape every array in *a* to *newshape*.

    :arg a: an array or array container; each leaf array is reshaped
        individually.
    :arg newshape: the target shape, forwarded to :func:`pytato.reshape`.
    :arg order: the index order, forwarded to :func:`pytato.reshape`.
    """
    # The mapped function must act on the leaf it receives (``ary``), not on
    # the outer ``a``: for a container, ``a`` itself is not a pytato array.
    return rec_map_array_container(
        lambda ary: pt.reshape(ary, newshape, order=order),
        a)
def ravel(self, a, order="C"): def ravel(self, a, order="C"):
""" """
...@@ -164,39 +147,99 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): ...@@ -164,39 +147,99 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
return rec_map_array_container(_rec_ravel, a) return rec_map_array_container(_rec_ravel, a)
def any(self, a): def transpose(self, a, axes=None):
return rec_map_reduce_array_container( return rec_multimap_array_container(pt.transpose, a, axes)
partial(reduce, pt.logical_or),
lambda subary: pt.any(subary), a) def broadcast_to(self, array, shape):
return rec_map_array_container(partial(pt.broadcast_to, shape=shape), array)
def concatenate(self, arrays, axis=0):
    """Apply :func:`pytato.concatenate` to *arrays* along *axis* through
    ``rec_multimap_array_container``.
    """
    joined = rec_multimap_array_container(pt.concatenate, arrays, axis)
    return joined
def stack(self, arrays, axis=0):
    """Stack *arrays* along a new *axis*, leaf by leaf, using
    :func:`pytato.stack`.
    """
    def _stack_leaves(*leaves):
        # One matching leaf from each entry of *arrays*.
        return pt.stack(arrays=leaves, axis=axis)

    return rec_multimap_array_container(_stack_leaves, *arrays)
# }}}
# {{{ logic functions
def all(self, a): def all(self, a):
return rec_map_reduce_array_container( return rec_map_reduce_array_container(
partial(reduce, pt.logical_and), partial(reduce, pt.logical_and),
lambda subary: pt.all(subary), a) lambda subary: pt.all(subary), a)
def array_equal(self, a, b): def any(self, a):
return rec_map_reduce_array_container(
partial(reduce, pt.logical_or),
lambda subary: pt.any(subary), a)
def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
actx = self._array_context actx = self._array_context
# NOTE: not all backends support `bool` properly, so use `int8` instead # NOTE: not all backends support `bool` properly, so use `int8` instead
false = actx.from_numpy(np.int8(False)) true_ary = actx.from_numpy(np.int8(True))
false_ary = actx.from_numpy(np.int8(False))
def rec_equal(x, y): def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> pt.Array:
if type(x) != type(y): if type(x) is not type(y):
return false return false_ary
try: try:
iterable = zip(serialize_container(x), serialize_container(y)) serialized_x = serialize_container(x)
serialized_y = serialize_container(y)
except NotAnArrayContainerError: except NotAnArrayContainerError:
assert isinstance(x, pt.Array)
assert isinstance(y, pt.Array)
if x.shape != y.shape: if x.shape != y.shape:
return false return false_ary
else: else:
return pt.all(pt.equal(x, y)) return pt.all(cast(pt.Array, pt.equal(x, y)))
else: else:
if len(serialized_x) != len(serialized_y):
return false_ary
return reduce( return reduce(
pt.logical_and, pt.logical_and,
[rec_equal(ix, iy) for (_, ix), (_, iy) in iterable] [(true_ary if kx_i == ky_i else false_ary)
) and rec_equal(x_i, y_i)
for (kx_i, x_i), (ky_i, y_i)
in zip(serialized_x, serialized_y, strict=True)],
true_ary)
return cast(Array, rec_equal(a, b))
# }}}
# {{{ mathematical functions
def sum(self, a, axis=None, dtype=None):
    """Sum *a* over *axis*; per-leaf sums of a container are added up.

    :raises NotImplementedError: if *dtype* is given and is not the dtype of
        the leaf array being summed.
    """
    def _leaf_sum(subary):
        acceptable_dtypes = [subary.dtype, None]
        if dtype not in acceptable_dtypes:
            raise NotImplementedError
        return pt.sum(subary, axis=axis)

    return rec_map_reduce_array_container(sum, _leaf_sum, a)
def amax(self, a, axis=None):
    """Maximum of *a* over *axis*; per-leaf results of a container are
    combined with :func:`pytato.maximum`.
    """
    combine_leaves = partial(reduce, pt.maximum)
    leaf_max = partial(pt.amax, axis=axis)
    return rec_map_reduce_array_container(combine_leaves, leaf_max, a)
max = amax
def amin(self, a, axis=None):
    """Minimum of *a* over *axis*; per-leaf results of a container are
    combined with :func:`pytato.minimum`.
    """
    combine_leaves = partial(reduce, pt.minimum)
    leaf_min = partial(pt.amin, axis=axis)
    return rec_map_reduce_array_container(combine_leaves, leaf_min, a)
min = amin
def absolute(self, a):
    """Alias of ``abs``: return ``self.abs(a)``."""
    result = self.abs(a)
    return result
return rec_equal(a, b) def vdot(self, a: Array, b: Array):
return rec_multimap_array_container(pt.vdot, a, b)
# }}} # }}}
from __future__ import annotations
__doc__ = """
.. autofunction:: transfer_from_numpy
.. autofunction:: transfer_to_numpy
"""
__copyright__ = """ __copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees Copyright (C) 2021 University of Illinois Board of Trustees
""" """
...@@ -23,11 +32,29 @@ THE SOFTWARE. ...@@ -23,11 +32,29 @@ THE SOFTWARE.
""" """
from typing import Any, Dict, Set, Tuple, Mapping from collections.abc import Mapping
from pytato.array import SizeParam, Placeholder, make_placeholder from typing import TYPE_CHECKING, Any, cast
from pytato.array import Array, DataWrapper, DictOfNamedArrays
from pytato.transform import CopyMapper from pytato.array import (
from pytools import UniqueNameGenerator AbstractResultWithNamedArrays,
Array,
Axis as PtAxis,
DataWrapper,
DictOfNamedArrays,
Placeholder,
SizeParam,
make_placeholder,
)
from pytato.target.loopy import LoopyPyOpenCLTarget
from pytato.transform import ArrayOrNames, CopyMapper
from pytools import UniqueNameGenerator, memoize_method
from arraycontext import ArrayContext
from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis
if TYPE_CHECKING:
import loopy as lp
class _DatawrapperToBoundPlaceholderMapper(CopyMapper): class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
...@@ -38,9 +65,9 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper): ...@@ -38,9 +65,9 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
""" """
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self.bound_arguments: Dict[str, Any] = {} self.bound_arguments: dict[str, Any] = {}
self.vng = UniqueNameGenerator() self.vng = UniqueNameGenerator()
self.seen_inputs: Set[str] = set() self.seen_inputs: set[str] = set()
def map_data_wrapper(self, expr: DataWrapper) -> Array: def map_data_wrapper(self, expr: DataWrapper) -> Array:
if expr.name is not None: if expr.name is not None:
...@@ -49,14 +76,16 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper): ...@@ -49,14 +76,16 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
f"{expr.name} => Illegal.") f"{expr.name} => Illegal.")
self.seen_inputs.add(expr.name) self.seen_inputs.add(expr.name)
# Normalizing names so that we more arrays can have the normalized DAG. # Normalizing names so that more arrays can have the same normalized DAG.
name = self.vng("_actx_dw") from pytato.codegen import _generate_name_for_temp
name = _generate_name_for_temp(expr, self.vng, "_actx_dw")
self.bound_arguments[name] = expr.data self.bound_arguments[name] = expr.data
return make_placeholder( return make_placeholder(
name=name, name=name,
shape=tuple(self.rec(s) if isinstance(s, Array) else s shape=tuple(cast(Array, self.rec(s)) if isinstance(s, Array) else s
for s in expr.shape), for s in expr.shape),
dtype=expr.dtype, dtype=expr.dtype,
axes=expr.axes,
tags=expr.tags) tags=expr.tags)
def map_size_param(self, expr: SizeParam) -> Array: def map_size_param(self, expr: SizeParam) -> Array:
...@@ -67,8 +96,9 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper): ...@@ -67,8 +96,9 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
" DatawrapperToBoundPlaceholderMapper.") " DatawrapperToBoundPlaceholderMapper.")
def _normalize_pt_expr(expr: DictOfNamedArrays) -> Tuple[DictOfNamedArrays, def _normalize_pt_expr(
Mapping[str, Any]]: expr: DictOfNamedArrays
) -> tuple[Array | AbstractResultWithNamedArrays, Mapping[str, Any]]:
""" """
Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a
normalized form of *expr*, with all instances of normalized form of *expr*, with all instances of
...@@ -80,4 +110,115 @@ def _normalize_pt_expr(expr: DictOfNamedArrays) -> Tuple[DictOfNamedArrays, ...@@ -80,4 +110,115 @@ def _normalize_pt_expr(expr: DictOfNamedArrays) -> Tuple[DictOfNamedArrays,
""" """
normalize_mapper = _DatawrapperToBoundPlaceholderMapper() normalize_mapper = _DatawrapperToBoundPlaceholderMapper()
normalized_expr = normalize_mapper(expr) normalized_expr = normalize_mapper(expr)
assert isinstance(normalized_expr, AbstractResultWithNamedArrays)
return normalized_expr, normalize_mapper.bound_arguments return normalized_expr, normalize_mapper.bound_arguments
def get_pt_axes_from_cl_axes(axes: tuple[ClAxis, ...]) -> tuple[PtAxis, ...]:
    """Build one :mod:`pytato` axis per entry of *axes*, each constructed
    from that entry's ``tags``.
    """
    pt_axes = [PtAxis(cl_axis.tags) for cl_axis in axes]
    return tuple(pt_axes)
def get_cl_axes_from_pt_axes(axes: tuple[PtAxis, ...]) -> tuple[ClAxis, ...]:
    """Build one taggable-CL-array axis per entry of *axes*, each constructed
    from that entry's ``tags``.
    """
    cl_axes = [ClAxis(pt_axis.tags) for pt_axis in axes]
    return tuple(cl_axes)
# {{{ arg-size-limiting loopy target
class ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget):
    """A :class:`~pytato.target.loopy.LoopyPyOpenCLTarget` whose loopy target
    is created with ``limit_arg_size_nbytes`` set.
    """

    def __init__(self, limit_arg_size_nbytes: int) -> None:
        super().__init__()
        self.limit_arg_size_nbytes = limit_arg_size_nbytes

    @memoize_method
    def get_loopy_target(self) -> lp.PyOpenCLTarget:
        """Return the (memoized) loopy target carrying the size limit."""
        import loopy

        nbytes_limit = self.limit_arg_size_nbytes
        return loopy.PyOpenCLTarget(limit_arg_size_nbytes=nbytes_limit)
# }}}
# {{{ Transfer mappers
class TransferFromNumpyMapper(CopyMapper):
    """A copy mapper that moves the host data held by
    :class:`~pytato.array.DataWrapper` instances onto the device via
    :meth:`~arraycontext.ArrayContext.from_numpy`.
    """

    def __init__(self, actx: ArrayContext) -> None:
        super().__init__()
        self.actx = actx

    def map_data_wrapper(self, expr: DataWrapper) -> Array:
        """Return a new :class:`~pytato.array.DataWrapper` around the
        transferred copy of ``expr.data``, keeping the shape, axes, tags and
        non-equality tags of *expr*.

        :raises ValueError: if ``expr.data`` is not a :class:`numpy.ndarray`.
        """
        import numpy as np

        host_data = expr.data
        if isinstance(host_data, np.ndarray):
            # NOTE: Ideally this would simply be
            #   self.actx.from_numpy(expr.data).tagged(expr.tags),
            # but there seems to be no way to carry over the
            # non_equality_tags that way, so the wrapper is rebuilt by hand.
            transferred = self.actx.from_numpy(host_data)
            assert isinstance(transferred, DataWrapper)

            # https://github.com/pylint-dev/pylint/issues/3893
            # pylint: disable=unexpected-keyword-arg
            return DataWrapper(
                data=transferred.data,
                shape=expr.shape,
                axes=expr.axes,
                tags=expr.tags,
                non_equality_tags=expr.non_equality_tags)

        raise ValueError("TransferFromNumpyMapper: tried to transfer data that "
                         "is already on the device")
class TransferToNumpyMapper(CopyMapper):
    """A copy mapper that converts the data held by
    :class:`~pytato.array.DataWrapper` instances into
    :class:`numpy.ndarray` instances via
    :meth:`~arraycontext.ArrayContext.to_numpy`.
    """

    def __init__(self, actx: ArrayContext) -> None:
        super().__init__()
        self.actx = actx

    def map_data_wrapper(self, expr: DataWrapper) -> Array:
        """Return a new :class:`~pytato.array.DataWrapper` around the
        :mod:`numpy` copy of ``expr.data``, keeping the shape, axes, tags and
        non-equality tags of *expr*.

        :raises ValueError: if ``expr.data`` is not a ``TaggableCLArray``.
        """
        import numpy as np

        import arraycontext.impl.pyopencl.taggable_cl_array as tga

        device_data = expr.data
        if not isinstance(device_data, tga.TaggableCLArray):
            raise ValueError("TransferToNumpyMapper: tried to transfer data that "
                             "is already on the host")

        host_data = self.actx.to_numpy(device_data)
        assert isinstance(host_data, np.ndarray)

        # https://github.com/pylint-dev/pylint/issues/3893
        # pylint: disable=unexpected-keyword-arg
        # type-ignore: discussed at
        # https://github.com/inducer/arraycontext/pull/289#discussion_r1855523967
        # possibly related: https://github.com/python/mypy/issues/17375
        return DataWrapper(  # type: ignore[call-arg]
            data=host_data,
            shape=expr.shape,
            axes=expr.axes,
            tags=expr.tags,
            non_equality_tags=expr.non_equality_tags)
def transfer_from_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
    """Apply a :class:`TransferFromNumpyMapper` bound to *actx* to *expr*,
    so that arrays contained in :class:`~pytato.array.DataWrapper` instances
    become device arrays (via :meth:`~arraycontext.ArrayContext.from_numpy`).
    """
    to_device = TransferFromNumpyMapper(actx)
    return to_device(expr)
def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
    """Apply a :class:`TransferToNumpyMapper` bound to *actx* to *expr*,
    so that arrays contained in :class:`~pytato.array.DataWrapper` instances
    become :class:`numpy.ndarray` instances
    (via :meth:`~arraycontext.ArrayContext.to_numpy`).
    """
    to_host = TransferToNumpyMapper(actx)
    return to_host(expr)
# }}}
# vim: foldmethod=marker
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
.. currentmodule:: arraycontext .. currentmodule:: arraycontext
.. autofunction:: make_loopy_program .. autofunction:: make_loopy_program
""" """
from __future__ import annotations
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
...@@ -27,8 +29,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -27,8 +29,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from collections.abc import Mapping
from typing import ClassVar
import numpy as np
import loopy as lp import loopy as lp
from loopy.version import MOST_RECENT_LANGUAGE_VERSION from loopy.version import MOST_RECENT_LANGUAGE_VERSION
from pytools import memoize_in
from arraycontext.container.traversal import multimapped_over_array_containers
from arraycontext.fake_numpy import BaseFakeNumpyNamespace
# {{{ loopy # {{{ loopy
...@@ -64,9 +75,96 @@ def get_default_entrypoint(t_unit): ...@@ -64,9 +75,96 @@ def get_default_entrypoint(t_unit):
except AttributeError: except AttributeError:
try: try:
return t_unit.root_kernel return t_unit.root_kernel
except AttributeError: except AttributeError as err:
raise TypeError("unable to find default entry point for loopy " raise TypeError("unable to find default entry point for loopy "
"translation unit") "translation unit") from err
def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes):
    """Return a loopy program that applies the function named *c_name*
    elementwise to *nargs* inputs with *naxes* axes each.

    The program reads arguments named ``inp0`` ... ``inp{nargs-1}`` and
    writes its result to ``out``. The program is named
    ``actx_special_{c_name}`` and tagged with ``ElementwiseMapKernelTag``.
    """
    # NOTE(review): presumably memoize_in caches the result on *actx*, keyed
    # by this function and the arguments of 'get' -- confirm against pytools.
    @memoize_in(actx, _get_scalar_func_loopy_program)
    def get(c_name, nargs, naxes):
        from pymbolic.primitives import Subscript, Variable

        # one index variable i0, i1, ... and one size parameter n0, n1, ...
        # per axis
        var_names = [f"i{i}" for i in range(naxes)]
        size_names = [f"n{i}" for i in range(naxes)]
        subscript = tuple(Variable(vname) for vname in var_names)

        from islpy import make_zero_and_vars
        v = make_zero_and_vars(var_names, params=size_names)

        # restrict each index variable to 0 <= i_k < n_k
        domain = v[0].domain()
        for vname, sname in zip(var_names, size_names, strict=True):
            domain = domain & v[0].le_set(v[vname]) & v[vname].lt_set(v[sname])

        # the unpacking fails unless the domain has exactly one basic set
        domain_bset, = domain.get_basic_sets()

        import loopy as lp

        from arraycontext.transform_metadata import ElementwiseMapKernelTag

        def sub(name: str) -> Variable | Subscript:
            # with naxes == 0 the subscript is empty, so the bare variable
            # is used instead of an (empty) subscript expression
            return Subscript(Variable(name), subscript) if subscript else Variable(name)

        # single statement: out[i0, ...] = c_name(inp0[i0, ...], inp1[i0, ...], ...)
        return make_loopy_program(
                [domain_bset], [
                    lp.Assignment(
                        sub("out"),
                        Variable(c_name)(*[sub(f"inp{i}") for i in range(nargs)]))
                    ], [
                        lp.GlobalArg("out", dtype=None, shape=lp.auto, offset=lp.auto)
                    ] + [
                        lp.GlobalArg(f"inp{i}", dtype=None, shape=lp.auto, offset=lp.auto)
                        for i in range(nargs)
                    # NOTE(review): the trailing '...' entry presumably asks
                    # loopy to infer the remaining arguments (the size
                    # parameters) -- confirm against the loopy documentation.
                    ] + [...],
                name=f"actx_special_{c_name}",
                tags=(ElementwiseMapKernelTag(),))

    return get(c_name, nargs, naxes)
class LoopyBasedFakeNumpyNamespace(BaseFakeNumpyNamespace):
    """A fake :mod:`numpy` namespace whose elementwise math functions are
    evaluated through generated loopy programs.

    Attribute lookup for a name listed in ``self._numpy_math_functions``
    returns a function (mapped over array containers) that

    * calls the :mod:`numpy` function of the same name directly if all
      arguments are scalars, and
    * otherwise runs the loopy program obtained from
      :func:`_get_scalar_func_loopy_program` via ``actx.call_loopy`` and
      returns its ``"out"`` output.

    .. attribute:: _numpy_to_c_arc_functions

        Maps :mod:`numpy` names of inverse trigonometric/hyperbolic functions
        to their C names.

    .. attribute:: _c_to_numpy_arc_functions

        The inverse of :attr:`_numpy_to_c_arc_functions`.
    """

    _numpy_to_c_arc_functions: ClassVar[Mapping[str, str]] = {
            "arcsin": "asin",
            "arccos": "acos",
            "arctan": "atan",
            "arctan2": "atan2",

            "arcsinh": "asinh",
            "arccosh": "acosh",
            "arctanh": "atanh",
            }

    _c_to_numpy_arc_functions: ClassVar[Mapping[str, str]] = {c_name: numpy_name
            for numpy_name, c_name in _numpy_to_c_arc_functions.items()}

    def __getattr__(self, name):
        """Return the loopy-backed implementation of the function *name*.

        :raises RuntimeError: if *name* is a C-style name (e.g. ``asin``);
            the message points at the :mod:`numpy` spelling to use instead.
        :raises AttributeError: if *name* is not listed in
            ``self._numpy_math_functions``.
        """
        # C-style names were removed from the namespace; reject them first so
        # that the caller is told which numpy name to use.
        if name in self._c_to_numpy_arc_functions:
            raise RuntimeError(f"'{name}' in ArrayContext.np has been removed. "
                               f"Use '{self._c_to_numpy_arc_functions[name]}' as in numpy. ")

        # limit which functions we try to hand off to loopy
        # (C-style names cannot reach this point, so only the numpy math
        # function list needs to be consulted)
        if name not in self._numpy_math_functions:
            raise AttributeError(
                f"'{type(self._array_context).__name__}.np' object "
                f"has no attribute '{name}'")

        # normalize to C names anyway
        c_name = self._numpy_to_c_arc_functions.get(name, name)

        def loopy_implemented_elwise_func(*args):
            if all(np.isscalar(ary) for ary in args):
                # *name* is always a numpy-style name here, so it can be
                # looked up in numpy as-is
                return getattr(np, name)(*args)

            actx = self._array_context
            # NOTE(review): the axis count is taken from the first argument
            # only; a scalar first argument combined with array arguments
            # would fail on '.shape' -- confirm whether callers rely on this.
            prg = _get_scalar_func_loopy_program(actx,
                    c_name, nargs=len(args), naxes=len(args[0].shape))
            outputs = actx.call_loopy(prg,
                    **{f"inp{i}": arg for i, arg in enumerate(args)})
            return outputs["out"]

        return multimapped_over_array_containers(loopy_implemented_elwise_func)
# }}} # }}}
......
"""
.. autoclass:: NameHint
"""
from __future__ import annotations
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
""" """
...@@ -22,36 +28,25 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -22,36 +28,25 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
import sys from dataclasses import dataclass
from pytools.tag import Tag
from warnings import warn
from pytools.tag import UniqueTag
# {{{ deprecation handling
try: @dataclass(frozen=True)
from meshmode.transform_metadata import FirstAxisIsElementsTag \ class NameHint(UniqueTag):
as _FirstAxisIsElementsTag """A tag acting on arrays or array axes. Express that :attr:`name` is a
except ImportError: useful starting point in forming an identifier for the tagged object.
# placeholder in case meshmode is too old to have it.
class _FirstAxisIsElementsTag(Tag): # type: ignore[no-redef]
pass
.. attribute:: name
if sys.version_info >= (3, 7): A string. Must be a valid Python identifier. Not necessarily unique.
def __getattr__(name): """
if name == "FirstAxisIsElementsTag": name: str
warn(f"'arraycontext.{name}' is deprecated. "
f"Use 'meshmode.transform_metadata.{name}' instead. "
f"'arraycontext.{name}' will continue to work until 2022.",
DeprecationWarning, stacklevel=2)
return _FirstAxisIsElementsTag
else:
raise AttributeError(name)
else:
FirstAxisIsElementsTag = _FirstAxisIsElementsTag
# }}} def __post_init__(self):
if not self.name.isidentifier():
raise ValueError("'name' must be an identifier")
# vim: foldmethod=marker # vim: foldmethod=marker
""" """
.. currentmodule:: arraycontext .. currentmodule:: arraycontext
.. autoclass:: PytestArrayContextFactory
.. autoclass:: PytestPyOpenCLArrayContextFactory .. autoclass:: PytestPyOpenCLArrayContextFactory
.. autofunction:: pytest_generate_tests_for_array_contexts .. autofunction:: pytest_generate_tests_for_array_contexts
.. autofunction:: pytest_generate_tests_for_pyopencl_array_context
""" """
from __future__ import annotations
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
...@@ -31,15 +33,25 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -31,15 +33,25 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from typing import Any, Callable, Dict, Sequence, Type, Union from collections.abc import Callable, Sequence
from typing import Any
import pyopencl as cl from arraycontext import NumpyArrayContext
from arraycontext.context import ArrayContext from arraycontext.context import ArrayContext
# {{{ array context factories # {{{ array context factories
class PytestPyOpenCLArrayContextFactory: class PytestArrayContextFactory:
@classmethod
def is_available(cls) -> bool:
return True
def __call__(self) -> ArrayContext:
raise NotImplementedError
class PytestPyOpenCLArrayContextFactory(PytestArrayContextFactory):
""" """
.. automethod:: __init__ .. automethod:: __init__
.. automethod:: __call__ .. automethod:: __call__
...@@ -51,6 +63,14 @@ class PytestPyOpenCLArrayContextFactory: ...@@ -51,6 +63,14 @@ class PytestPyOpenCLArrayContextFactory:
""" """
self.device = device self.device = device
@classmethod
def is_available(cls) -> bool:
try:
import pyopencl # noqa: F401
return True
except ImportError:
return False
def get_command_queue(self): def get_command_queue(self):
# Get rid of leftovers from past tests. # Get rid of leftovers from past tests.
# CL implementations are surprisingly limited in how many # CL implementations are surprisingly limited in how many
...@@ -61,17 +81,25 @@ class PytestPyOpenCLArrayContextFactory: ...@@ -61,17 +81,25 @@ class PytestPyOpenCLArrayContextFactory:
from gc import collect from gc import collect
collect() collect()
import pyopencl as cl
# On Intel CPU CL, existence of a command queue does not ensure that # On Intel CPU CL, existence of a command queue does not ensure that
# the context survives. # the context survives.
ctx = cl.Context([self.device]) ctx = cl.Context([self.device])
return ctx, cl.CommandQueue(ctx) return ctx, cl.CommandQueue(ctx)
def __call__(self) -> ArrayContext:
raise NotImplementedError
class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFactory): class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFactory):
force_device_scalars = True # Deprecated, remove in 2025.
_force_device_scalars = True
@property
def force_device_scalars(self):
from warnings import warn
warn(
"force_device_scalars is deprecated and will be removed in 2025.",
DeprecationWarning, stacklevel=2)
return self._force_device_scalars
@property @property
def actx_class(self): def actx_class(self):
...@@ -84,31 +112,45 @@ class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFact ...@@ -84,31 +112,45 @@ class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFact
# holding a reference to the context to keep it alive in turn. # holding a reference to the context to keep it alive in turn.
# On some implementations (notably Intel CPU), holding a reference # On some implementations (notably Intel CPU), holding a reference
# to a queue does not keep the context alive. # to a queue does not keep the context alive.
ctx, queue = self.get_command_queue() _ctx, queue = self.get_command_queue()
alloc = None
if queue.device.platform.name == "NVIDIA CUDA":
from pyopencl.tools import ImmediateAllocator
alloc = ImmediateAllocator(queue)
from warnings import warn
warn("Disabling SVM due to memory leak "
"in Nvidia CL when running pytest. "
"See https://github.com/inducer/arraycontext/issues/196",
stacklevel=1)
return self.actx_class( return self.actx_class(
queue, queue,
force_device_scalars=self.force_device_scalars) allocator=alloc)
def __str__(self): def __str__(self):
return ("<%s for <pyopencl.Device '%s' on '%s'>>" % return (f"<{self.actx_class.__name__} "
( f"for <pyopencl.Device '{self.device.name.strip()}' "
self.actx_class.__name__, f"on '{self.device.platform.name.strip()}'>>")
self.device.name.strip(),
self.device.platform.name.strip()))
class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars(
_PytestPyOpenCLArrayContextFactoryWithClass):
force_device_scalars = False
class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory):
class _PytestPytatoPyOpenCLArrayContextFactory( @classmethod
PytestPyOpenCLArrayContextFactory): def is_available(cls) -> bool:
try:
import pyopencl # noqa: F401
import pytato # noqa: F401
return True
except ImportError:
return False
@property @property
def actx_class(self): def actx_class(self):
from arraycontext import PytatoPyOpenCLArrayContext from arraycontext import PytatoPyOpenCLArrayContext
return PytatoPyOpenCLArrayContext actx_cls = PytatoPyOpenCLArrayContext
return actx_cls
def __call__(self): def __call__(self):
# The ostensibly pointless assignment to *ctx* keeps the CL context alive # The ostensibly pointless assignment to *ctx* keeps the CL context alive
...@@ -116,28 +158,107 @@ class _PytestPytatoPyOpenCLArrayContextFactory( ...@@ -116,28 +158,107 @@ class _PytestPytatoPyOpenCLArrayContextFactory(
# holding a reference to the context to keep it alive in turn. # holding a reference to the context to keep it alive in turn.
# On some implementations (notably Intel CPU), holding a reference # On some implementations (notably Intel CPU), holding a reference
# to a queue does not keep the context alive. # to a queue does not keep the context alive.
ctx, queue = self.get_command_queue() _ctx, queue = self.get_command_queue()
return self.actx_class(queue)
alloc = None
if queue.device.platform.name == "NVIDIA CUDA":
from pyopencl.tools import ImmediateAllocator
alloc = ImmediateAllocator(queue)
from warnings import warn
warn("Disabling SVM due to memory leak "
"in Nvidia CL when running pytest. "
"See https://github.com/inducer/arraycontext/issues/196",
stacklevel=1)
return self.actx_class(queue, allocator=alloc)
def __str__(self): def __str__(self):
return ("<PytatoPyOpenCLArrayContext for <pyopencl.Device '%s' on '%s'>>" % return ("<PytatoPyOpenCLArrayContext for "
( f"<pyopencl.Device '{self.device.name.strip()}' "
self.device.name.strip(), f"on '{self.device.platform.name.strip()}'>>")
self.device.platform.name.strip()))
_ARRAY_CONTEXT_FACTORY_REGISTRY: \ class _PytestEagerJaxArrayContextFactory(PytestArrayContextFactory):
Dict[str, Type[PytestPyOpenCLArrayContextFactory]] = { def __init__(self, *args, **kwargs):
"pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, pass
"pyopencl-deprecated":
_PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars, @classmethod
"pytato-pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, def is_available(cls) -> bool:
} try:
import jax # noqa: F401
return True
except ImportError:
return False
def __call__(self):
from jax import config
from arraycontext import EagerJAXArrayContext
config.update("jax_enable_x64", True)
return EagerJAXArrayContext()
def __str__(self):
return "<EagerJAXArrayContext>"
class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory):
    """Factory producing :class:`~arraycontext.PytatoJAXArrayContext`
    instances. Constructor arguments are accepted and ignored.
    """

    def __init__(self, *args, **kwargs):
        pass

    @classmethod
    def is_available(cls) -> bool:
        """Return *True* if both :mod:`jax` and :mod:`pytato` are importable."""
        try:
            import jax  # noqa: F401
            import pytato  # noqa: F401
        except ImportError:
            return False
        else:
            return True

    def __call__(self):
        """Switch JAX to 64-bit mode and return a new array context."""
        from jax import config as jax_config

        from arraycontext import PytatoJAXArrayContext

        jax_config.update("jax_enable_x64", True)
        actx = PytatoJAXArrayContext()
        return actx

    def __str__(self):
        return "<PytatoJAXArrayContext>"
# {{{ _PytestArrayContextFactory
class _NumpyArrayContextForTests(NumpyArrayContext):
    """A :class:`NumpyArrayContext` for use in tests whose
    :meth:`transform_loopy_program` applies no transformation.
    """

    def transform_loopy_program(self, t_unit):
        # hand the translation unit back unchanged
        return t_unit
class _PytestNumpyArrayContextFactory(PytestArrayContextFactory):
    """Factory producing :class:`_NumpyArrayContextForTests` instances."""

    def __init__(self, *args, **kwargs):
        # positional and keyword arguments are accepted but not used
        super().__init__()

    def __call__(self):
        # a new array context is created on every call
        return _NumpyArrayContextForTests()

    def __str__(self):
        return "<NumpyArrayContext>"
# }}}
_ARRAY_CONTEXT_FACTORY_REGISTRY: dict[str, type[PytestArrayContextFactory]] = {
"pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass,
"pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory,
"pytato:jax": _PytestPytatoJaxArrayContextFactory,
"eagerjax": _PytestEagerJaxArrayContextFactory,
"numpy": _PytestNumpyArrayContextFactory,
}
def register_pytest_array_context_factory( def register_pytest_array_context_factory(
name: str, name: str,
factory: Type[PytestPyOpenCLArrayContextFactory]) -> None: factory: type[PytestArrayContextFactory]) -> None:
if name in _ARRAY_CONTEXT_FACTORY_REGISTRY: if name in _ARRAY_CONTEXT_FACTORY_REGISTRY:
raise ValueError(f"factory '{name}' already exists") raise ValueError(f"factory '{name}' already exists")
...@@ -149,7 +270,7 @@ def register_pytest_array_context_factory( ...@@ -149,7 +270,7 @@ def register_pytest_array_context_factory(
# {{{ pytest integration # {{{ pytest integration
def pytest_generate_tests_for_array_contexts( def pytest_generate_tests_for_array_contexts(
factories: Sequence[Union[str, Type[PytestPyOpenCLArrayContextFactory]]], *, factories: Sequence[str | type[PytestArrayContextFactory]], *,
factory_arg_name: str = "actx_factory", factory_arg_name: str = "actx_factory",
) -> Callable[[Any], None]: ) -> Callable[[Any], None]:
"""Parametrize tests for pytest to use an :class:`~arraycontext.ArrayContext`. """Parametrize tests for pytest to use an :class:`~arraycontext.ArrayContext`.
...@@ -166,10 +287,7 @@ def pytest_generate_tests_for_array_contexts( ...@@ -166,10 +287,7 @@ def pytest_generate_tests_for_array_contexts(
"pyopencl", "pyopencl",
]) ])
to use the :mod:`pyopencl`-based array context. For :mod:`pyopencl`-based to use the :mod:`pyopencl`-based array context.
contexts :func:`pyopencl.tools.pytest_generate_tests_for_pyopencl` is used
as a backend, which allows specifying the ``PYOPENCL_TEST`` environment
variable for device selection.
The environment variable ``ARRAYCONTEXT_TEST`` can also be used to The environment variable ``ARRAYCONTEXT_TEST`` can also be used to
overwrite any chosen implementations through *factories*. This is a overwrite any chosen implementations through *factories*. This is a
...@@ -177,11 +295,7 @@ def pytest_generate_tests_for_array_contexts( ...@@ -177,11 +295,7 @@ def pytest_generate_tests_for_array_contexts(
Current supported implementations include: Current supported implementations include:
* ``"pyopencl"``, which creates a :class:`~arraycontext.PyOpenCLArrayContext` * ``"pyopencl"``, which creates a :class:`~arraycontext.PyOpenCLArrayContext`.
with ``force_device_scalars=True``.
* ``"pyopencl-deprecated"``, which creates a
:class:`~arraycontext.PyOpenCLArrayContext` with
``force_device_scalars=False``.
* ``"pytato-pyopencl"``, which creates a * ``"pytato-pyopencl"``, which creates a
:class:`~arraycontext.PytatoPyOpenCLArrayContext`. :class:`~arraycontext.PytatoPyOpenCLArrayContext`.
...@@ -217,9 +331,19 @@ def pytest_generate_tests_for_array_contexts( ...@@ -217,9 +331,19 @@ def pytest_generate_tests_for_array_contexts(
else: else:
raise ValueError(f"unknown array contexts: {unknown_factories}") raise ValueError(f"unknown array contexts: {unknown_factories}")
unique_factories = set([ available_factories = {
_ARRAY_CONTEXT_FACTORY_REGISTRY.get(factory, factory) # type: ignore[misc] factory for key in unique_factories
for factory in unique_factories]) for factory in [_ARRAY_CONTEXT_FACTORY_REGISTRY.get(key, key)]
if (
not isinstance(factory, str)
and issubclass(factory, PytestArrayContextFactory)
and factory.is_available())
}
from pytools import partition
pyopencl_factories, other_factories = partition(
lambda factory: issubclass(factory, PytestPyOpenCLArrayContextFactory),
available_factories)
# }}} # }}}
...@@ -234,6 +358,7 @@ def pytest_generate_tests_for_array_contexts( ...@@ -234,6 +358,7 @@ def pytest_generate_tests_for_array_contexts(
return return
arg_values, ids = cl_tools.get_pyopencl_fixture_arg_values() arg_values, ids = cl_tools.get_pyopencl_fixture_arg_values()
empty_arg_dict = dict.fromkeys(arg_values[0])
# }}} # }}}
...@@ -246,67 +371,34 @@ def pytest_generate_tests_for_array_contexts( ...@@ -246,67 +371,34 @@ def pytest_generate_tests_for_array_contexts(
"'ctx_factory' / 'ctx_getter' as arguments.") "'ctx_factory' / 'ctx_getter' as arguments.")
arg_values_with_actx = [] arg_values_with_actx = []
for arg_dict in arg_values:
if pyopencl_factories:
for arg_dict in arg_values:
arg_values_with_actx.extend([
{factory_arg_name: factory(arg_dict["device"]), **arg_dict}
for factory in pyopencl_factories
])
if other_factories:
arg_values_with_actx.extend([ arg_values_with_actx.extend([
{factory_arg_name: factory(arg_dict["device"]), **arg_dict} {factory_arg_name: factory(), **empty_arg_dict}
for factory in unique_factories for factory in other_factories
]) ])
else: else:
arg_values_with_actx = arg_values arg_values_with_actx = arg_values
arg_value_tuples = [
tuple(arg_dict[name] for name in arg_names)
for arg_dict in arg_values_with_actx
]
# }}} # }}}
# Sort the actx's so that parallel pytest works # NOTE: sorts the args so that parallel pytest works
arg_value_tuples = sorted(arg_value_tuples, key=lambda x: x.__str__()) arg_value_tuples = sorted([
tuple(arg_dict[name] for name in arg_names)
for arg_dict in arg_values_with_actx
], key=lambda x: str(x))
metafunc.parametrize(arg_names, arg_value_tuples, ids=ids) metafunc.parametrize(arg_names, arg_value_tuples, ids=ids)
return inner return inner
def pytest_generate_tests_for_pyopencl_array_context(metafunc) -> None:
"""Parametrize tests for pytest to use a
:class:`~arraycontext.PyOpenCLArrayContext`.
Performs device enumeration analogously to
:func:`pyopencl.tools.pytest_generate_tests_for_pyopencl`.
Using the line:
.. code-block:: python
from arraycontext import (
pytest_generate_tests_for_pyopencl_array_context
as pytest_generate_tests)
in your pytest test scripts allows you to use the argument ``actx_factory``,
in your test functions, and they will automatically be
run once for each OpenCL device/platform in the system, as appropriate,
with an argument-less function that returns an
:class:`~arraycontext.ArrayContext` when called.
It also allows you to specify the ``PYOPENCL_TEST`` environment variable
for device selection.
"""
from warnings import warn
warn("pytest_generate_tests_for_pyopencl_array_context is deprecated. "
"Use 'pytest_generate_tests = "
"arraycontext.pytest_generate_tests_for_array_contexts"
"([\"pyopencl-deprecated\"])' instead. "
"pytest_generate_tests_for_pyopencl_array_context will stop working "
"in 2022.",
DeprecationWarning, stacklevel=2)
pytest_generate_tests_for_array_contexts([
"pyopencl-deprecated",
], factory_arg_name="actx_factory")(metafunc)
# }}} # }}}
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
.. autoclass:: CommonSubexpressionTag .. autoclass:: CommonSubexpressionTag
.. autoclass:: ElementwiseMapKernelTag .. autoclass:: ElementwiseMapKernelTag
""" """
from __future__ import annotations
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
......
VERSION = (2021, 1) from __future__ import annotations
VERSION_TEXT = ".".join(str(i) for i in VERSION)
from importlib import metadata
def _parse_version(version: str) -> tuple[tuple[int, ...], str]:
import re
m = re.match(r"^([0-9.]+)([a-z0-9]*?)$", VERSION_TEXT)
assert m is not None
return tuple(int(nr) for nr in m.group(1).split(".")), m.group(2)
VERSION_TEXT = metadata.version("arraycontext")
VERSION, VERSION_STATUS = _parse_version(VERSION_TEXT)
...@@ -4,17 +4,3 @@ The Array Context Abstraction ...@@ -4,17 +4,3 @@ The Array Context Abstraction
.. automodule:: arraycontext .. automodule:: arraycontext
.. automodule:: arraycontext.context .. automodule:: arraycontext.context
Implementations of the Array Context Abstraction
================================================
Array context based on :mod:`pyopencl.array`
--------------------------------------------
.. automodule:: arraycontext.impl.pyopencl
Lazy/Deferred evaluation array context based on :mod:`pytato`
-------------------------------------------------------------
.. automodule:: arraycontext.impl.pytato
from importlib import metadata
from urllib.request import urlopen from urllib.request import urlopen
_conf_url = \ _conf_url = \
"https://raw.githubusercontent.com/inducer/sphinxconfig/main/sphinxconfig.py" "https://raw.githubusercontent.com/inducer/sphinxconfig/main/sphinxconfig.py"
with urlopen(_conf_url) as _inf: with urlopen(_conf_url) as _inf:
...@@ -7,29 +9,37 @@ with urlopen(_conf_url) as _inf: ...@@ -7,29 +9,37 @@ with urlopen(_conf_url) as _inf:
copyright = "2021, University of Illinois Board of Trustees" copyright = "2021, University of Illinois Board of Trustees"
author = "Arraycontext Contributors" author = "Arraycontext Contributors"
release = metadata.version("arraycontext")
ver_dic = {} version = ".".join(release.split(".")[:2])
exec(compile(open("../arraycontext/version.py").read(), "../arraycontext/version.py",
"exec"), ver_dic)
version = ".".join(str(x) for x in ver_dic["VERSION"])
release = ver_dic["VERSION_TEXT"]
autodoc_type_aliases = {
"DeviceScalar": "arraycontext.DeviceScalar",
"DeviceArray": "arraycontext.DeviceArray",
}
intersphinx_mapping = { intersphinx_mapping = {
"https://docs.python.org/3/": None, "jax": ("https://jax.readthedocs.io/en/latest/", None),
"https://numpy.org/doc/stable/": None, "loopy": ("https://documen.tician.de/loopy", None),
"https://documen.tician.de/pytools": None, "meshmode": ("https://documen.tician.de/meshmode", None),
"https://documen.tician.de/pymbolic": None, "numpy": ("https://numpy.org/doc/stable/", None),
"https://documen.tician.de/pyopencl": None, "pymbolic": ("https://documen.tician.de/pymbolic", None),
"https://documen.tician.de/pytato": None, "pyopencl": ("https://documen.tician.de/pyopencl", None),
"https://documen.tician.de/loopy": None, "pytato": ("https://documen.tician.de/pytato", None),
"https://documen.tician.de/meshmode": None, "pytest": ("https://docs.pytest.org/en/latest/", None),
"https://docs.pytest.org/en/latest/": None, "python": ("https://docs.python.org/3/", None),
"pytools": ("https://documen.tician.de/pytools", None),
} }
# Some modules need to import things just so that sphinx can resolve symbols in
# type annotations. Often, we do not want these imports (e.g. of PyOpenCL) when
# in normal use (because they would introduce unintended side effects or hard
# dependencies). This flag exists so that these imports only occur during doc
# build. Since sphinx appears to resolve type hints lexically (as it should),
# this needs to be cross-module (since, e.g. an inherited arraycontext
# docstring can be read by sphinx when building meshmode, a dependent package),
# this needs a setting of the same name across all packages involved, that's
# why this name is as global-sounding as it is.
import sys import sys
sys.ARRAYCONTEXT_BUILDING_SPHINX_DOCS = True
sys._BUILDING_SPHINX_DOCS = True
nitpick_ignore_regex = [
["py:class", r"arraycontext\.context\.ContainerOrScalarT"],
]
Implementations of the Array Context Abstraction
================================================
..
When adding a new array context here, make sure to also add it to and run
```
doc/make_numpy_coverage_table.py
```
to update the coverage table below!
Array context based on :mod:`numpy`
--------------------------------------------
.. automodule:: arraycontext.impl.numpy
Array context based on :mod:`pyopencl.array`
--------------------------------------------
.. automodule:: arraycontext.impl.pyopencl
Lazy/Deferred evaluation array context based on :mod:`pytato`
-------------------------------------------------------------
.. automodule:: arraycontext.impl.pytato
Array context based on :mod:`jax.numpy`
---------------------------------------
.. automodule:: arraycontext.impl.jax
.. _numpy-coverage:
:mod:`numpy` coverage
---------------------
This is a list of functionality implemented by :attr:`arraycontext.ArrayContext.np`.
.. note::
Only functions and methods that have at least one implementation are listed.
.. include:: numpy_coverage.rst
...@@ -7,6 +7,7 @@ implementations for: ...@@ -7,6 +7,7 @@ implementations for:
- :mod:`numpy` - :mod:`numpy`
- :mod:`pyopencl` - :mod:`pyopencl`
- :mod:`jax.numpy`
- :mod:`pytato` (for lazy/deferred evaluation) - :mod:`pytato` (for lazy/deferred evaluation)
- Debugging - Debugging
- Profiling - Profiling
...@@ -14,11 +15,45 @@ implementations for: ...@@ -14,11 +15,45 @@ implementations for:
:mod:`arraycontext` started life as an array abstraction for use with the :mod:`arraycontext` started life as an array abstraction for use with the
:mod:`meshmode` unstructured discretization package. :mod:`meshmode` unstructured discretization package.
Design Guidelines
-----------------
Here are some of the guidelines we aim to follow in :mod:`arraycontext`. There
exist numerous other, related efforts, such as the `Python array API standard
<https://data-apis.org/array-api/latest/purpose_and_scope.html>`__. These
points may aid in clarifying and differentiating our objectives.
- The array context is about exposing the common subset of operations
available in immutable and mutable arrays. As a result, the interface
does *not* seek to support interfaces that provide, enable, or are typically
used only with in-place mutation.
For example: The equivalents of :func:`numpy.empty` were deprecated
and will eventually be removed.
- Each array context offers a specific subset of :mod:`numpy` under
:attr:`arraycontext.ArrayContext.np`. Functions under this namespace
must be unconditionally :mod:`numpy`-compatible, that is, they may not
offer an interface beyond what numpy offers. Functions that are
incompatible, for example by supporting tag metadata
(cf. :meth:`arraycontext.ArrayContext.einsum`) should live under the
:class:`~arraycontext.ArrayContext` directly.
- Similarly, we strive to minimize redundancy between attributes of
:class:`~arraycontext.ArrayContext` and :attr:`arraycontext.ArrayContext.np`.
For example: ``ArrayContext.empty_like`` was deprecated.
- Array containers are data structures that may contain arrays.
See :mod:`arraycontext.container`. We strive to support these, where sensible,
in :class:`~arraycontext.ArrayContext` and :attr:`arraycontext.ArrayContext.np`.
Contents Contents
-------- --------
.. toctree:: .. toctree::
array_context array_context
implementations
container container
other other
misc misc
......
"""
Workflow:
1. If a new array context is implemented, it should be added to
:func:`initialize_contexts`.
2. If a new function is implemented, it should be added to the
corresponding ``write_section_name`` function.
3. Once everything is added, regenerate the tables using
.. code::
python make_numpy_coverage_table.py numpy_coverage.rst
"""
from __future__ import annotations
import pathlib
from mako.template import Template
import arraycontext
# {{{ templating
# reST preamble written once before the tables. It injects the CSS for the
# ``red`` and ``green`` classes and declares the matching reST roles, which the
# table cells rendered from TABLE_TEMPLATE use.
HEADER = """
.. raw:: html
<style> .red {color:red} </style>
<style> .green {color:green} </style>
.. role:: red
.. role:: green
"""
# Mako template for a single support table. Expected render arguments:
#   title: section title; it is underlined with '~' repeated len(title) times.
#   contexts: array context instances; each contributes one column, labelled
#       with the name of its type.
#   numpy_functions_for_context: mapping of numpy function name to a
#       (sphinx-directive, in_context) pair, where in_context maps a context
#       type to a flag string. A type missing from in_context defaults to
#       "yes"; the capitalized flag is shown green when it equals "Yes" and
#       red otherwise.
TABLE_TEMPLATE = Template("""
${title}
${'~' * len(title)}
.. list-table::
:header-rows: 1
* - Function
% for ctx in contexts:
- :class:`~arraycontext.${type(ctx).__name__}`
% endfor
% for name, (directive, in_context) in numpy_functions_for_context.items():
* - :${directive}:`numpy.${name}`
% for ctx in contexts:
<%
flag = in_context.get(type(ctx), "yes").capitalize()
color = "green" if flag == "Yes" else "red"
%> - :${color}:`${flag}`
% endfor
% endfor
""")
def initialize_contexts():
    """Create one instance of every array context listed in the tables.

    The two OpenCL-backed contexts share a single command queue that is
    created on whatever context :func:`pyopencl.create_some_context` returns.

    :returns: a list of array context instances; their order is the order
        in which the caller receives them.
    """
    import pyopencl as cl

    queue = cl.CommandQueue(cl.create_some_context())

    actxs = []
    actxs.append(
        arraycontext.PyOpenCLArrayContext(queue, force_device_scalars=True))
    actxs.append(arraycontext.EagerJAXArrayContext())
    actxs.append(arraycontext.PytatoPyOpenCLArrayContext(queue))
    actxs.append(arraycontext.PytatoJAXArrayContext())
    return actxs
def build_supported_functions(funcs, contexts):
    """Determine which of *funcs* each array context is missing.

    :arg funcs: an iterable of ``(sphinx-directive, name)`` pairs, where
        *name* must be an attribute of the :mod:`numpy` namespace.
    :arg contexts: an iterable of objects with an ``np`` attribute that is
        probed for each *name*.
    :returns: a :class:`dict` mapping each *name* to a tuple
        ``(directive, in_context)``. *in_context* maps ``type(ctx)`` to
        ``"No"`` for every context whose ``np`` lookup raised
        :exc:`AttributeError`; contexts that have the attribute get no entry.
    :raises ValueError: if a *name* does not exist in :mod:`numpy`.
    """
    import numpy as np

    support = {}
    for role, func_name in funcs:
        # Catch typos in the tables early: every entry must be a real
        # numpy name.
        if not hasattr(np, func_name):
            raise ValueError(f"'{func_name}' not found in numpy namespace")

        missing_in = {}
        for actx in contexts:
            try:
                getattr(actx.np, func_name)
            except AttributeError:
                # Only the absence is recorded; presence is implied.
                missing_in[type(actx)] = "No"

        support[func_name] = (role, missing_in)

    return support
# }}}
# {{{ writing
def write_array_creation_routines(outf, contexts):
    """Write the "Array creation routines" support table to *outf*.

    :arg outf: an object with a ``write`` method accepting the rendered text.
    :arg contexts: the array contexts that form the table columns.
    """
    # https://numpy.org/doc/stable/reference/routines.array-creation.html
    names = ("empty_like", "ones_like", "zeros_like", "full_like", "copy")
    # Entries are (sphinx-directive, name) pairs.
    funcs = tuple(("func", name) for name in names)

    supported = build_supported_functions(funcs, contexts)
    outf.write(TABLE_TEMPLATE.render(
        title="Array creation routines",
        contexts=contexts,
        numpy_functions_for_context=supported,
    ))
def write_array_manipulation_routines(outf, contexts):
    """Write the "Array manipulation routines" support table to *outf*.

    :arg outf: an object with a ``write`` method accepting the rendered text.
    :arg contexts: the array contexts that form the table columns.
    """
    # https://numpy.org/doc/stable/reference/routines.array-manipulation.html
    names = (
        "reshape",
        "ravel",
        "transpose",
        "broadcast_to",
        "concatenate",
        "stack",
    )
    # Entries are (sphinx-directive, name) pairs.
    funcs = tuple(("func", name) for name in names)

    supported = build_supported_functions(funcs, contexts)
    outf.write(TABLE_TEMPLATE.render(
        title="Array manipulation routines",
        contexts=contexts,
        numpy_functions_for_context=supported,
    ))
def write_linear_algebra(outf, contexts):
    """Write the "Linear algebra" support table to *outf*.

    :arg outf: an object with a ``write`` method accepting the rendered text.
    :arg contexts: the array contexts that form the table columns.
    """
    # https://numpy.org/doc/stable/reference/routines.linalg.html
    # Entries are (sphinx-directive, name) pairs.
    funcs = (("func", "vdot"),)

    supported = build_supported_functions(funcs, contexts)
    outf.write(TABLE_TEMPLATE.render(
        title="Linear algebra",
        contexts=contexts,
        numpy_functions_for_context=supported,
    ))
def write_logic_functions(outf, contexts):
    """Write the "Logic Functions" support table to *outf*.

    :arg outf: an object with a ``write`` method accepting the rendered text.
    :arg contexts: the array contexts that form the table columns.
    """
    # https://numpy.org/doc/stable/reference/routines.logic.html
    # Entries are (sphinx-directive, name) pairs; the table rows follow the
    # order of this sequence.
    truth_tests = [("func", name) for name in ("all", "any")]
    comparisons = [
        ("data", name)
        for name in (
            "greater",
            "greater_equal",
            "less",
            "less_equal",
            "equal",
            "not_equal",
        )
    ]

    supported = build_supported_functions(truth_tests + comparisons, contexts)
    outf.write(TABLE_TEMPLATE.render(
        title="Logic Functions",
        contexts=contexts,
        numpy_functions_for_context=supported,
    ))
def write_mathematical_functions(outf, contexts):
    """Write the "Mathematical functions" support table to *outf*.

    :arg outf: an object with a ``write`` method accepting the rendered text.
    :arg contexts: the array contexts that form the table columns.
    """
    # https://numpy.org/doc/stable/reference/routines.math.html
    # Row order of the table.
    names = (
        "sin", "cos", "tan",
        "arcsin", "arccos", "arctan", "arctan2",
        "sinh", "cosh", "tanh",
        "floor", "ceil",
        "sum",
        "exp", "log", "log10",
        "real", "imag", "conjugate",
        "maximum", "amax",
        "minimum", "amin",
        "sqrt", "absolute", "fabs",
    )
    # These are referenced with the "func" directive; everything else in
    # *names* uses "data".
    func_directive_names = {"sum", "real", "imag", "amax", "amin"}
    funcs = [
        ("func" if name in func_directive_names else "data", name)
        for name in names
    ]

    supported = build_supported_functions(funcs, contexts)
    outf.write(TABLE_TEMPLATE.render(
        title="Mathematical functions",
        contexts=contexts,
        numpy_functions_for_context=supported,
    ))
def write_searching_sorting_and_counting(outf, contexts):
    """Write the "Sorting, searching, and counting" support table to *outf*.

    :arg outf: an object with a ``write`` method accepting the rendered text.
    :arg contexts: the array contexts that form the table columns.
    """
    # https://numpy.org/doc/stable/reference/routines.sort.html
    # Entries are (sphinx-directive, name) pairs.
    funcs = (("func", "where"),)

    supported = build_supported_functions(funcs, contexts)
    outf.write(TABLE_TEMPLATE.render(
        title="Sorting, searching, and counting",
        contexts=contexts,
        numpy_functions_for_context=supported,
    ))
# }}}
if __name__ == "__main__":
    # Usage: python make_numpy_support_table.py [filename]
    # Without a filename the tables are written to standard output.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("filename", nargs="?", type=pathlib.Path, default=None)
    args = parser.parse_args()

    def write(outf):
        # Emit the shared header followed by one table per section. Every
        # ``write_*`` section function defined in this file must be listed
        # here, otherwise its table is silently absent from the output.
        outf.write(HEADER)
        write_array_creation_routines(outf, ctxs)
        write_array_manipulation_routines(outf, ctxs)
        write_linear_algebra(outf, ctxs)
        write_logic_functions(outf, ctxs)
        write_mathematical_functions(outf, ctxs)
        write_searching_sorting_and_counting(outf, ctxs)

    # Must be bound before write() is called: write() reads it as a global.
    ctxs = initialize_contexts()

    if args.filename:
        # Explicit encoding keeps the generated file independent of the
        # locale of the machine that regenerates it.
        with open(args.filename, "w", encoding="utf-8") as outf:
            write(outf)
    else:
        import sys

        write(sys.stdout)
.. raw:: html
<style> .red {color:red} </style>
<style> .green {color:green} </style>
.. role:: red
.. role:: green
Array creation routines
~~~~~~~~~~~~~~~~~~~~~~~
.. list-table::
:header-rows: 1
* - Function
- :class:`~arraycontext.PyOpenCLArrayContext`
- :class:`~arraycontext.EagerJAXArrayContext`
- :class:`~arraycontext.PytatoPyOpenCLArrayContext`
- :class:`~arraycontext.PytatoJAXArrayContext`
* - :func:`numpy.empty_like`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.ones_like`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.zeros_like`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.full_like`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.copy`
- :green:`Yes`
- :green:`Yes`
- :red:`No`
- :red:`No`
Array manipulation routines
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. list-table::
:header-rows: 1
* - Function
- :class:`~arraycontext.PyOpenCLArrayContext`
- :class:`~arraycontext.EagerJAXArrayContext`
- :class:`~arraycontext.PytatoPyOpenCLArrayContext`
- :class:`~arraycontext.PytatoJAXArrayContext`
* - :func:`numpy.reshape`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.ravel`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.transpose`
- :red:`No`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.broadcast_to`
- :red:`No`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.concatenate`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.stack`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
Linear algebra
~~~~~~~~~~~~~~
.. list-table::
:header-rows: 1
* - Function
- :class:`~arraycontext.PyOpenCLArrayContext`
- :class:`~arraycontext.EagerJAXArrayContext`
- :class:`~arraycontext.PytatoPyOpenCLArrayContext`
- :class:`~arraycontext.PytatoJAXArrayContext`
* - :func:`numpy.vdot`
- :green:`Yes`
- :green:`Yes`
- :red:`No`
- :red:`No`
Logic Functions
~~~~~~~~~~~~~~~
.. list-table::
:header-rows: 1
* - Function
- :class:`~arraycontext.PyOpenCLArrayContext`
- :class:`~arraycontext.EagerJAXArrayContext`
- :class:`~arraycontext.PytatoPyOpenCLArrayContext`
- :class:`~arraycontext.PytatoJAXArrayContext`
* - :func:`numpy.all`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.any`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.greater`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.greater_equal`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.less`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.less_equal`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.equal`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.not_equal`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
Mathematical functions
~~~~~~~~~~~~~~~~~~~~~~
.. list-table::
:header-rows: 1
* - Function
- :class:`~arraycontext.PyOpenCLArrayContext`
- :class:`~arraycontext.EagerJAXArrayContext`
- :class:`~arraycontext.PytatoPyOpenCLArrayContext`
- :class:`~arraycontext.PytatoJAXArrayContext`
* - :data:`numpy.sin`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.cos`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.tan`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.arcsin`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.arccos`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.arctan`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.arctan2`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.sinh`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.cosh`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.tanh`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.floor`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.ceil`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.sum`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.exp`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.log`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.log10`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.real`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.imag`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.conjugate`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.maximum`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.amax`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.minimum`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.amin`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.sqrt`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.absolute`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.fabs`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
Other functionality Other functionality
=================== ===================
.. _metadata:
Metadata ("tags") for Arrays and Array Axes
-------------------------------------------
.. automodule:: arraycontext.metadata
:class:`~arraycontext.ArrayContext`-generating fixture for :mod:`pytest` :class:`~arraycontext.ArrayContext`-generating fixture for :mod:`pytest`
------------------------------------------------------------------------ ------------------------------------------------------------------------
......
#! /bin/sh #! /bin/sh
rsync --verbose --archive --delete _build/html/* doc-upload:doc/arraycontext rsync --verbose --archive --delete _build/html/ doc-upload:doc/arraycontext
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "arraycontext"
version = "2024.0"
description = "Choose your favorite numpy-workalike"
readme = "README.rst"
license = { text = "MIT" }
authors = [
{ name = "Andreas Kloeckner", email = "inform@tiker.net" },
]
requires-python = ">=3.10"
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Other Audience",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python",
"Programming Language :: Python :: 3 :: Only",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Information Analysis",
"Topic :: Scientific/Engineering :: Mathematics",
"Topic :: Software Development :: Libraries",
"Topic :: Utilities",
]
dependencies = [
"immutabledict>=4.1",
"numpy",
"pytools>=2024.1.3",
# for Self
"typing_extensions>=4",
]
[project.optional-dependencies]
jax = [
"jax>=0.4",
]
pyopencl = [
"islpy>=2024.1",
"loopy>=2024.1",
"pyopencl>=2024.1",
]
pytato = [
"pytato>=2021.1",
]
test = [
"mypy",
"pytest",
"ruff",
]
[project.urls]
Documentation = "https://documen.tician.de/arraycontext"
Homepage = "https://github.com/inducer/arraycontext"
[tool.ruff]
preview = true
[tool.ruff.lint]
extend-select = [
"B", # flake8-bugbear
"C", # flake8-comprehensions
"E", # pycodestyle
"F", # pyflakes
"G", # flake8-logging-format
"I", # flake8-isort
"N", # pep8-naming
"NPY", # numpy
"Q", # flake8-quotes
"RUF", # ruff
"UP", # pyupgrade
"W", # pycodestyle
"SIM",
]
extend-ignore = [
"C90", # McCabe complexity
"E221", # multiple spaces before operator
"E226", # missing whitespace around arithmetic operator
"E402", # module-level import not at top of file
]
[tool.ruff.lint.flake8-quotes]
docstring-quotes = "double"
inline-quotes = "double"
multiline-quotes = "double"
[tool.ruff.lint.isort]
combine-as-imports = true
known-first-party = [
"jax",
"loopy",
"pymbolic",
"pyopencl",
"pytato",
"pytools",
]
known-local-folder = [
"arraycontext",
"testlib",
]
lines-after-imports = 2
required-imports = ["from __future__ import annotations"]
[tool.ruff.lint.per-file-ignores]
"doc/conf.py" = ["I002"]
# To avoid a requirement of array container definitions being someplace importable
# from @dataclass_array_container.
"test/test_utils.py" = ["I002"]
[tool.mypy]
python_version = "3.10"
warn_unused_ignores = true
# TODO: enable this
# check_untyped_defs = true
[[tool.mypy.overrides]]
module = [
"islpy.*",
"loopy.*",
"meshmode.*",
"pymbolic",
"pymbolic.*",
"pyopencl.*",
"jax.*",
]
ignore_missing_imports = true
[tool.typos.default]
extend-ignore-re = [
"(?Rm)^.*(#|//)\\s*spellchecker:\\s*disable-line$"
]