Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • inducer/arraycontext
  • kaushikcfd/arraycontext
  • fikl2/arraycontext
3 results
Show changes
Showing
with 3397 additions and 571 deletions
......@@ -2,6 +2,9 @@
.. currentmodule:: arraycontext
.. autoclass:: PyOpenCLArrayContext
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
"""
......@@ -26,15 +29,23 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from functools import partial, reduce
import operator
from functools import partial, reduce
import numpy as np
from arraycontext.fake_numpy import \
BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace
from arraycontext.container import NotAnArrayContainerError, serialize_container
from arraycontext.container.traversal import (
rec_multimap_array_container, rec_map_array_container,
rec_map_reduce_array_container,
)
rec_map_array_container,
rec_map_reduce_array_container,
rec_multimap_array_container,
rec_multimap_reduce_array_container,
)
from arraycontext.context import Array, ArrayOrContainer
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
try:
import pyopencl as cl # noqa: F401
......@@ -45,114 +56,82 @@ except ImportError:
# {{{ fake numpy
class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace):
class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
def _get_fake_numpy_linalg_namespace(self):
return _PyOpenCLFakeNumpyLinalgNamespace(self._array_context)
# {{{ comparisons
# NOTE: the order of these follows the order in numpy docs
# NOTE: when adding a function here, also add it to `array_context.rst` docs!
# FIXME: This should be documentation, not a comment.
# These are here mainly because some arrays may choose to interpret
# equality comparison as a binary predicate of structural identity,
# i.e. more like "are you two equal", and not like numpy semantics.
# These operations provide access to numpy-style comparisons in that
# case.
# {{{ array creation routines
def equal(self, x, y):
return rec_multimap_array_container(operator.eq, x, y)
def zeros(self, shape, dtype) -> TaggableCLArray:
import arraycontext.impl.pyopencl.taggable_cl_array as tga
return tga.zeros(self._array_context.queue, shape, dtype,
allocator=self._array_context.allocator)
def not_equal(self, x, y):
return rec_multimap_array_container(operator.ne, x, y)
def empty_like(self, ary):
from warnings import warn
warn(f"{type(self._array_context).__name__}.np.empty_like is "
"deprecated and will stop working in 2023. Prefer actx.np.zeros_like "
"instead.",
DeprecationWarning, stacklevel=2)
def greater(self, x, y):
return rec_multimap_array_container(operator.gt, x, y)
import arraycontext.impl.pyopencl.taggable_cl_array as tga
actx = self._array_context
def greater_equal(self, x, y):
return rec_multimap_array_container(operator.ge, x, y)
def _empty_like(array):
return tga.empty(actx.queue, array.shape, array.dtype,
allocator=actx.allocator, axes=array.axes, tags=array.tags)
def less(self, x, y):
return rec_multimap_array_container(operator.lt, x, y)
return actx._rec_map_container(_empty_like, ary)
def less_equal(self, x, y):
return rec_multimap_array_container(operator.le, x, y)
def zeros_like(self, ary):
import arraycontext.impl.pyopencl.taggable_cl_array as tga
actx = self._array_context
# }}}
def _zeros_like(array):
return tga.zeros(
actx.queue, array.shape, array.dtype,
allocator=actx.allocator, axes=array.axes, tags=array.tags)
return actx._rec_map_container(_zeros_like, ary, default_scalar=0)
def ones_like(self, ary):
def _ones_like(subary):
ones = self._array_context.empty_like(subary)
ones.fill(1)
return ones
return self.full_like(ary, 1)
return self._new_like(ary, _ones_like)
def full_like(self, ary, fill_value):
import arraycontext.impl.pyopencl.taggable_cl_array as tga
actx = self._array_context
def maximum(self, x, y):
return rec_multimap_array_container(
partial(cl_array.maximum, queue=self._array_context.queue),
x, y)
def _full_like(subary):
filled = tga.empty(
actx.queue, subary.shape, subary.dtype,
allocator=actx.allocator, axes=subary.axes, tags=subary.tags)
filled.fill(fill_value)
def minimum(self, x, y):
return rec_multimap_array_container(
partial(cl_array.minimum, queue=self._array_context.queue),
x, y)
return filled
def where(self, criterion, then, else_):
def where_inner(inner_crit, inner_then, inner_else):
if isinstance(inner_crit, bool):
return inner_then if inner_crit else inner_else
return cl_array.if_positive(inner_crit != 0, inner_then, inner_else,
queue=self._array_context.queue)
return actx._rec_map_container(_full_like, ary, default_scalar=fill_value)
return rec_multimap_array_container(where_inner, criterion, then, else_)
def copy(self, ary):
def _copy(subary):
return subary.copy(queue=self._array_context.queue)
def sum(self, a, dtype=None):
result = rec_map_reduce_array_container(
sum,
partial(cl_array.sum, dtype=dtype, queue=self._array_context.queue),
a)
return self._array_context._rec_map_container(_copy, ary)
if not self._array_context._force_device_scalars:
result = result.get()[()]
return result
def arange(self, *args, **kwargs):
return cl_array.arange(self._array_context.queue, *args, **kwargs)
def min(self, a):
queue = self._array_context.queue
result = rec_map_reduce_array_container(
partial(reduce, partial(cl_array.minimum, queue=queue)),
partial(cl_array.min, queue=queue),
a)
# }}}
if not self._array_context._force_device_scalars:
result = result.get()[()]
return result
# {{{ array manipulation routines
def max(self, a):
queue = self._array_context.queue
result = rec_map_reduce_array_container(
partial(reduce, partial(cl_array.maximum, queue=queue)),
partial(cl_array.max, queue=queue),
def reshape(self, a, newshape, order="C"):
return rec_map_array_container(
lambda ary: ary.reshape(newshape, order=order),
a)
if not self._array_context._force_device_scalars:
result = result.get()[()]
return result
def stack(self, arrays, axis=0):
return rec_multimap_array_container(
lambda *args: cl_array.stack(arrays=args, axis=axis,
queue=self._array_context.queue),
*arrays)
def reshape(self, a, newshape):
return cl_array.reshape(a, newshape)
def concatenate(self, arrays, axis=0):
return cl_array.concatenate(
arrays, axis,
self._array_context.queue,
self._array_context.allocator
)
def ravel(self, a, order="C"):
def _rec_ravel(a):
if order in "FC":
......@@ -175,16 +154,209 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace):
return rec_map_array_container(_rec_ravel, a)
def concatenate(self, arrays, axis=0):
return cl_array.concatenate(
arrays, axis,
self._array_context.queue,
self._array_context.allocator
)
def stack(self, arrays, axis=0):
return rec_multimap_array_container(
lambda *args: cl_array.stack(arrays=args, axis=axis,
queue=self._array_context.queue),
*arrays)
# }}}
# {{{ linear algebra
def vdot(self, x, y, dtype=None):
from arraycontext import rec_multimap_reduce_array_container
result = rec_multimap_reduce_array_container(
return rec_multimap_reduce_array_container(
sum,
partial(cl_array.vdot, dtype=dtype, queue=self._array_context.queue),
x, y)
if not self._array_context._force_device_scalars:
result = result.get()[()]
return result
# }}}
# {{{ logic functions
def all(self, a):
    """Reduce *a* (an array or array container) to a single truth value.

    Each non-scalar leaf is reduced with its own ``all`` method, each
    scalar leaf is converted to an ``np.int8``, and the per-leaf results
    are folded together with ``cl_array.minimum``.
    """
    cl_queue = self._array_context.queue

    def _all(ary):
        # Non-scalar leaves know how to reduce themselves.
        if not np.isscalar(ary):
            return ary.all(queue=cl_queue)
        # ``all`` here is the builtin; the result is wrapped as int8 so it
        # can take part in the ``minimum`` fold below.
        return np.int8(all([ary]))

    fold_leaves = partial(reduce, partial(cl_array.minimum, queue=cl_queue))
    return rec_map_reduce_array_container(fold_leaves, _all, a)
def any(self, a):
    """Reduce *a* (an array or array container) to a single truth value.

    Each non-scalar leaf is reduced with its own ``any`` method, each
    scalar leaf is converted to an ``np.int8``, and the per-leaf results
    are folded together with ``cl_array.maximum``.
    """
    cl_queue = self._array_context.queue

    def _any(ary):
        # Non-scalar leaves know how to reduce themselves.
        if not np.isscalar(ary):
            return ary.any(queue=cl_queue)
        # ``any`` here is the builtin; the result is wrapped as int8 so it
        # can take part in the ``maximum`` fold below.
        return np.int8(any([ary]))

    fold_leaves = partial(reduce, partial(cl_array.maximum, queue=cl_queue))
    return rec_map_reduce_array_container(fold_leaves, _any, a)
def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
    """Compare *a* and *b* for structural and element-wise equality.

    The result is the false value as soon as the two arguments (or any
    pair of corresponding sub-containers) differ in type, number of
    serialized components, component keys, or leaf shape. Otherwise the
    leaf comparisons ``(x == y).all()`` are folded together with
    ``cl_array.minimum``.

    :returns: the context's representation of ``np.int8(True)`` or
        ``np.int8(False)``, or the outcome of folding the leaf comparisons.
    """
    actx = self._array_context
    queue = actx.queue

    # NOTE: pyopencl doesn't like `bool` much, so use `int8` instead
    true_ary = actx.from_numpy(np.int8(True))
    false_ary = actx.from_numpy(np.int8(False))

    def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> cl_array.Array:
        if type(x) is not type(y):
            return false_ary

        try:
            serialized_x = serialize_container(x)
            serialized_y = serialize_container(y)
        except NotAnArrayContainerError:
            # Leaves: anything that is not a container must be an array.
            assert isinstance(x, cl_array.Array)
            assert isinstance(y, cl_array.Array)

            if x.shape != y.shape:
                return false_ary
            else:
                return (x == y).all()
        else:
            if len(serialized_x) != len(serialized_y):
                return false_ary

            # Mismatching keys short-circuit to the false value without
            # descending. A conditional expression is used instead of
            # ``<array> and <array>`` so that no array is implicitly
            # converted to a Python bool just to pick an operand.
            return reduce(
                partial(cl_array.minimum, queue=queue),
                [rec_equal(x_i, y_i) if kx_i == ky_i else false_ary
                 for (kx_i, x_i), (ky_i, y_i)
                 in zip(serialized_x, serialized_y, strict=True)],
                true_ary)

    return rec_equal(a, b)
# FIXME: This should be documentation, not a comment.
# These are here mainly because some arrays may choose to interpret
# equality comparison as a binary predicate of structural identity,
# i.e. more like "are you two equal", and not like numpy semantics.
# These operations provide access to numpy-style comparisons in that
# case.
def greater(self, x, y):
    """Apply ``operator.gt`` (``x > y``) leaf-by-leaf to *x* and *y*."""
    return rec_multimap_array_container(operator.gt, x, y)

def greater_equal(self, x, y):
    """Apply ``operator.ge`` (``x >= y``) leaf-by-leaf to *x* and *y*."""
    return rec_multimap_array_container(operator.ge, x, y)

def less(self, x, y):
    """Apply ``operator.lt`` (``x < y``) leaf-by-leaf to *x* and *y*."""
    return rec_multimap_array_container(operator.lt, x, y)

def less_equal(self, x, y):
    """Apply ``operator.le`` (``x <= y``) leaf-by-leaf to *x* and *y*."""
    return rec_multimap_array_container(operator.le, x, y)

def equal(self, x, y):
    """Apply ``operator.eq`` (``x == y``) leaf-by-leaf to *x* and *y*."""
    return rec_multimap_array_container(operator.eq, x, y)

def not_equal(self, x, y):
    """Apply ``operator.ne`` (``x != y``) leaf-by-leaf to *x* and *y*."""
    return rec_multimap_array_container(operator.ne, x, y)

def logical_or(self, x, y):
    """Apply ``cl_array.logical_or`` leaf-by-leaf to *x* and *y*."""
    return rec_multimap_array_container(cl_array.logical_or, x, y)

def logical_and(self, x, y):
    """Apply ``cl_array.logical_and`` leaf-by-leaf to *x* and *y*."""
    return rec_multimap_array_container(cl_array.logical_and, x, y)

def logical_not(self, x):
    """Apply ``cl_array.logical_not`` to every leaf of *x*."""
    return rec_map_array_container(cl_array.logical_not, x)
# }}}
# {{{ mathematical functions
def sum(self, a, axis=None, dtype=None):
    """Sum all entries of *a* (an array or array container).

    Only full reductions are supported: *axis* must be *None* or list
    every axis of each leaf in increasing order (an ``int`` is treated as
    a one-element tuple).

    :raises NotImplementedError: for any other *axis*.
    """
    if isinstance(axis, int):
        axis = (axis,)

    def _rec_sum(ary):
        supported = (None, tuple(range(ary.ndim)))
        if axis in supported:
            return cl_array.sum(
                ary, dtype=dtype, queue=self._array_context.queue)
        raise NotImplementedError(f"Sum over '{axis}' axes not supported.")

    # ``sum`` is the builtin here: it adds up the per-leaf results.
    return rec_map_reduce_array_container(sum, _rec_sum, a)
def maximum(self, x, y):
    """Apply ``cl_array.maximum`` leaf-by-leaf to *x* and *y*."""
    return rec_multimap_array_container(
        partial(cl_array.maximum, queue=self._array_context.queue),
        x, y)
def amax(self, a, axis=None):
    """Return the maximum over all entries of *a*.

    Only full reductions are supported: *axis* must be *None* or list
    every axis of each leaf in increasing order (an ``int`` is treated as
    a one-element tuple). Per-leaf maxima are folded together with
    ``cl_array.maximum``.

    :raises NotImplementedError: for any other *axis*.
    """
    cl_queue = self._array_context.queue
    if isinstance(axis, int):
        axis = (axis,)

    def _rec_max(ary):
        supported = (None, tuple(range(ary.ndim)))
        if axis in supported:
            return cl_array.max(ary, queue=cl_queue)
        raise NotImplementedError(f"Max. over '{axis}' axes not supported.")

    fold_leaves = partial(reduce, partial(cl_array.maximum, queue=cl_queue))
    return rec_map_reduce_array_container(fold_leaves, _rec_max, a)
max = amax
def minimum(self, x, y):
    """Apply ``cl_array.minimum`` leaf-by-leaf to *x* and *y*."""
    return rec_multimap_array_container(
        partial(cl_array.minimum, queue=self._array_context.queue),
        x, y)
def amin(self, a, axis=None):
    """Return the minimum over all entries of *a*.

    Only full reductions are supported: *axis* must be *None* or list
    every axis of each leaf in increasing order (an ``int`` is treated as
    a one-element tuple). Per-leaf minima are folded together with
    ``cl_array.minimum``.

    :raises NotImplementedError: for any other *axis*.
    """
    cl_queue = self._array_context.queue
    if isinstance(axis, int):
        axis = (axis,)

    def _rec_min(ary):
        supported = (None, tuple(range(ary.ndim)))
        if axis in supported:
            return cl_array.min(ary, queue=cl_queue)
        raise NotImplementedError(f"Min. over '{axis}' axes not supported.")

    fold_leaves = partial(reduce, partial(cl_array.minimum, queue=cl_queue))
    return rec_map_reduce_array_container(fold_leaves, _rec_min, a)
min = amin
def absolute(self, a):
    """Alias that forwards to ``self.abs``."""
    return self.abs(a)
# }}}
# {{{ sorting, searching, and counting
def where(self, criterion, then, else_):
    """Select from *then* or *else_* leaf-by-leaf according to *criterion*.

    A leaf criterion that is a ``bool`` or ``np.bool_`` picks one of the
    two operands outright; any other criterion is handed to
    ``cl_array.if_positive`` as ``criterion != 0``.
    """
    def where_inner(inner_crit, inner_then, inner_else):
        if not isinstance(inner_crit, bool | np.bool_):
            return cl_array.if_positive(
                inner_crit != 0, inner_then, inner_else,
                queue=self._array_context.queue)

        # Plain truth value: no element-wise selection needed.
        if inner_crit:
            return inner_then
        return inner_else

    return rec_multimap_array_container(where_inner, criterion, then, else_)
# }}}
# }}}
......
"""
.. autoclass:: TaggableCLArray
.. autoclass:: Axis
.. autofunction:: to_tagged_cl_array
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import numpy as np
import pyopencl.array as cla
from pytools import memoize
from pytools.tag import Tag, Taggable, ToTagSetConvertible
# {{{ utils
@dataclass(frozen=True, eq=True)
class Axis(Taggable):
    """
    Holds the tags attached to a single dimension of a
    :class:`TaggableCLArray`.
    """

    tags: frozenset[Tag]

    def _with_new_tags(self, tags: frozenset[Tag]) -> Axis:
        # The dataclass is frozen, so build a modified copy rather than
        # assigning to ``self.tags``.
        import dataclasses
        return dataclasses.replace(self, tags=tags)
@memoize
def _construct_untagged_axes(ndim: int) -> tuple[Axis, ...]:
    """Return a tuple of *ndim* :class:`Axis` objects carrying no tags."""
    untagged = [Axis(frozenset()) for _ in range(ndim)]
    return tuple(untagged)
def _unwrap_cl_array(ary: cla.Array) -> dict[str, Any]:
    """Collect the attributes of *ary* as constructor keyword arguments.

    The returned dictionary is what :class:`TaggableCLArray` instances in
    this module are rebuilt from (``type(self)(None, ..., **kwargs)``).
    """
    # (keyword name, attribute read from *ary*), in output order.
    keyword_to_attr = (
        ("shape", "shape"),
        ("dtype", "dtype"),
        ("allocator", "allocator"),
        ("strides", "strides"),
        ("data", "base_data"),
        ("offset", "offset"),
        ("events", "events"),
        ("_context", "context"),
        ("_queue", "queue"),
        ("_size", "size"),
    )

    kwargs: dict[str, Any] = {}
    for keyword, attr in keyword_to_attr:
        kwargs[keyword] = getattr(ary, attr)
    kwargs["_fast"] = True
    return kwargs
# }}}
# {{{ TaggableCLArray
class TaggableCLArray(cla.Array, Taggable):
    """
    A :class:`pyopencl.array.Array` with additional metadata. This is used by
    :class:`~arraycontext.PytatoPyOpenCLArrayContext` to preserve tags for data
    while frozen, and also in a similar capacity by
    :class:`~arraycontext.PyOpenCLArrayContext`.

    .. attribute:: axes

        A :class:`tuple` of instances of :class:`Axis`, with one :class:`Axis`
        for each dimension of the array.

    .. attribute:: tags

        A :class:`frozenset` of :class:`pytools.tag.Tag`. Typically intended to
        record application-specific metadata to drive the optimizations in
        :meth:`arraycontext.PyOpenCLArrayContext.transform_loopy_program`.
    """
    def __init__(self, cq, shape, dtype, order="C", allocator=None,
                 data=None, offset=0, strides=None, events=None, _flags=None,
                 _fast=False, _size=None, _context=None, _queue=None,
                 axes=None, tags=frozenset()):
        """
        All arguments except *axes* and *tags* are forwarded unchanged to the
        :class:`pyopencl.array.Array` constructor.

        :arg axes: one :class:`Axis` per array dimension, or *None* to use
            untagged axes.
        :arg tags: a :class:`frozenset` of tags for the array as a whole.
        """
        super().__init__(cq=cq, shape=shape, dtype=dtype,
                         order=order, allocator=allocator,
                         data=data, offset=offset,
                         strides=strides, events=events,
                         _flags=_flags, _fast=_fast,
                         _size=_size, _context=_context,
                         _queue=_queue)

        # Argument validation is skipped when Python runs with -O.
        if __debug__:
            if not isinstance(tags, frozenset):
                raise TypeError("tags are not a frozenset")

            if axes is not None and len(axes) != self.ndim:
                raise ValueError("axes length does not match array dimension: "
                                 f"got {len(axes)} axes for {self.ndim}d array")

        if axes is None:
            axes = _construct_untagged_axes(self.ndim)

        self.tags = tags
        self.axes = axes

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(shape={self.shape}, dtype={self.dtype}, "
                f"tags={self.tags}, axes={self.axes})")

    def copy(self, queue=cla._copy_queue):
        """Copy the array via the base class and re-attach tags and axes."""
        ary = super().copy(queue=queue)
        return type(self)(None, tags=self.tags, axes=self.axes,
                          **_unwrap_cl_array(ary))

    def _with_new_tags(self, tags: frozenset[Tag]) -> TaggableCLArray:
        # New wrapper around the same underlying data, with *tags* replaced.
        return type(self)(None, tags=tags, axes=self.axes,
                          **_unwrap_cl_array(self))

    def with_tagged_axis(self, iaxis: int,
                         tags: ToTagSetConvertible) -> TaggableCLArray:
        """
        Returns a new array object wrapping the same data as *self*, with
        the *iaxis*-th axis additionally tagged with *tags*.

        :arg iaxis: index of the axis to tag. Negative values count from
            the last axis.
        :raises IndexError: if *iaxis* is out of range.
        """
        naxes = len(self.axes)
        if not -naxes <= iaxis < naxes:
            raise IndexError(
                f"axis index {iaxis} is out of range for an array "
                f"with {naxes} axes")

        if iaxis < 0:
            # Without this, iaxis == -1 would make ``axes[iaxis+1:]`` equal
            # to ``axes[0:]`` and duplicate every axis in the result.
            iaxis += naxes

        new_axes = (self.axes[:iaxis]
                    + (self.axes[iaxis].tagged(tags),)
                    + self.axes[iaxis+1:])

        return type(self)(None, tags=self.tags, axes=new_axes,
                          **_unwrap_cl_array(self))
def to_tagged_cl_array(ary: cla.Array,
                       axes: tuple[Axis, ...] | None = None,
                       tags: frozenset[Tag] = frozenset()) -> TaggableCLArray:
    """
    Returns a :class:`TaggableCLArray` that is constructed from the data in
    *ary* along with the metadata from *axes* and *tags*. If *ary* is already a
    :class:`TaggableCLArray`, the new *tags* and *axes* are added to the
    existing ones.

    :arg axes: An instance of :class:`Axis` for each dimension of the
        array. If passed *None*, then initialized to an :class:`Axis`
        with no tags attached for each dimension.
    :arg tags: tags for the array as a whole; passed through
        :func:`pytools.tag.normalize_tags`.
    :raises TypeError: if *ary* is not a :class:`pyopencl.array.Array`.
    :raises ValueError: if the number of *axes* differs from ``ary.ndim``.
    """
    # Check the type first so that unsupported inputs are reported as such
    # instead of failing on a missing ``ndim`` attribute below.
    if not isinstance(ary, cla.Array):
        raise TypeError(f"unsupported array type: '{type(ary).__name__}'")

    if axes is not None and len(axes) != ary.ndim:
        raise ValueError("axes length does not match array dimension: "
                         f"got {len(axes)} axes for {ary.ndim}d array")

    from pytools.tag import normalize_tags
    tags = normalize_tags(tags)

    if isinstance(ary, TaggableCLArray):
        # Already taggable: merge the new metadata into the existing one.
        if axes is not None:
            for i, axis in enumerate(axes):
                ary = ary.with_tagged_axis(i, axis.tags)

        if tags:
            ary = ary.tagged(tags)

        return ary

    # Plain pyopencl array: wrap its data together with the metadata.
    return TaggableCLArray(None, tags=tags, axes=axes,
                           **_unwrap_cl_array(ary))
# }}}
# {{{ creation
def empty(queue, shape, dtype=float, *,
          axes: tuple[Axis, ...] | None = None,
          tags: frozenset[Tag] = frozenset(),
          order: str = "C",
          allocator=None) -> TaggableCLArray:
    """Create a :class:`TaggableCLArray` of the given *shape* and *dtype*.

    *dtype* is converted with :class:`numpy.dtype` unless it is *None*;
    every other argument is handed to the :class:`TaggableCLArray`
    constructor unchanged.
    """
    normalized_dtype = None if dtype is None else np.dtype(dtype)

    return TaggableCLArray(
        queue, shape, normalized_dtype,
        axes=axes, tags=tags,
        order=order, allocator=allocator)
def zeros(queue, shape, dtype=float, *,
          axes: tuple[Axis, ...] | None = None,
          tags: frozenset[Tag] = frozenset(),
          order: str = "C",
          allocator=None) -> TaggableCLArray:
    """Create a :class:`TaggableCLArray` via :func:`empty` and zero-fill it."""
    metadata = {"axes": axes, "tags": tags}
    zeroed = empty(
        queue, shape, dtype=dtype, order=order, allocator=allocator,
        **metadata)

    zeroed._zero_fill()
    return zeroed
def to_device(queue, ary, *,
              axes: tuple[Axis, ...] | None = None,
              tags: frozenset[Tag] = frozenset(),
              allocator=None):
    """Pass *ary* through ``cla.to_device`` and attach *axes* and *tags*.

    The result of ``cla.to_device`` is wrapped with
    :func:`to_tagged_cl_array`.
    """
    device_ary = cla.to_device(queue, ary, allocator=allocator)
    return to_tagged_cl_array(device_ary, axes=axes, tags=tags)
# }}}
"""
from __future__ import annotations
__doc__ = """
.. currentmodule:: arraycontext
A :mod:`pytato`-based array context defers the evaluation of an array until its
A :mod:`pytato`-based array context defers the evaluation of an array until it is
frozen. The execution contexts for the evaluations are specific to an
:class:`~arraycontext.ArrayContext` type. For ex.
:class:`~arraycontext.ArrayContext` type. For example,
:class:`~arraycontext.PytatoPyOpenCLArrayContext` uses :mod:`pyopencl` to
JIT-compile and execute the array expressions.
Following :mod:`pytato`-based array context are provided:
The following :mod:`pytato`-based array contexts are provided:
.. autoclass:: PytatoPyOpenCLArrayContext
.. autoclass:: PytatoJAXArrayContext
Compiling a python callable
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Compiling a Python callable (Internal)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: arraycontext.impl.pytato.compile
Utils
^^^^^
.. automodule:: arraycontext.impl.pytato.utils
"""
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
......@@ -41,15 +51,193 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from arraycontext.context import ArrayContext
import abc
import sys
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
import numpy as np
from typing import Any, Callable, Union, Sequence
from pytools.tag import Tag
from pytools import memoize_method
from pytools.tag import Tag, ToTagSetConvertible, normalize_tags
from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainer,
ScalarLike,
UntransformedCodeWarning,
)
from arraycontext.metadata import NameHint
if TYPE_CHECKING:
import loopy as lp
import pyopencl as cl
import pytato
if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
import pyopencl as cl
import logging
logger = logging.getLogger(__name__)
# {{{ tag conversion
def _preprocess_array_tags(tags: ToTagSetConvertible) -> frozenset[Tag]:
    """Normalize *tags* and translate a ``NameHint`` into ``PrefixNamed``.

    If the normalized tag set contains a
    :class:`arraycontext.metadata.NameHint`, it is removed and a
    ``pytato.tags.PrefixNamed`` carrying the same name is added. A warning
    is issued when a ``PrefixNamed`` tag was already present. Tag sets
    without a ``NameHint`` are returned as normalized.
    """
    tags = normalize_tags(tags)

    name_hints = [tag for tag in tags if isinstance(tag, NameHint)]
    if not name_hints:
        return tags

    # Unpacking fails (ValueError) if more than one NameHint is present.
    (name_hint,) = name_hints

    from pytato.tags import PrefixNamed
    prefix_nameds = [tag for tag in tags if isinstance(tag, PrefixNamed)]

    if prefix_nameds:
        # Likewise, at most one pre-existing PrefixNamed is tolerated.
        (prefix_named,) = prefix_nameds
        from warnings import warn
        warn("When converting a "
             f"arraycontext.metadata.NameHint('{name_hint.name}') "
             "to pytato.tags.PrefixNamed, "
             f"PrefixNamed('{prefix_named.prefix}') "
             "was already present.", stacklevel=1)

    with_prefix = tags | frozenset({PrefixNamed(name_hint.name)})
    return with_prefix - {name_hint}
# }}}
class _NotOnlyDataWrappers(Exception):  # noqa: N818
    """Internal signal: a leaf other than an already-materialized array
    was encountered."""
# {{{ _BasePytatoArrayContext
class _BasePytatoArrayContext(ArrayContext, abc.ABC):
    """
    An abstract :class:`ArrayContext` that uses :mod:`pytato` data types to
    represent arrays.

    .. automethod:: __init__
    .. automethod:: transform_dag
    .. automethod:: compile
    """

    def __init__(
            self, *,
            compile_trace_callback: Callable[[Any, str, Any], None] | None = None
            ) -> None:
        """
        :arg compile_trace_callback: A function of three arguments
            *(what, stage, ir)*, where *what* identifies the object
            being compiled, *stage* is a string describing the compilation
            pass, and *ir* is an object containing the intermediate
            representation. This interface should be considered
            unstable.
        """
        super().__init__()

        import pytato as pt
        # Caches consulted when freezing: the generated program and the
        # (transformed DAG, function name) pair, keyed by expression.
        self._freeze_prg_cache: dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {}
        self._dag_transform_cache: dict[
                pt.DictOfNamedArrays,
                tuple[pt.DictOfNamedArrays, str]] = {}

        if compile_trace_callback is None:
            # Default to a no-op so the callback can be invoked
            # unconditionally.
            def _compile_trace_callback(what, stage, ir):
                pass

            compile_trace_callback = _compile_trace_callback

        self._compile_trace_callback = compile_trace_callback

    def _get_fake_numpy_namespace(self):
        from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace
        return PytatoFakeNumpyNamespace(self)

    # ``abc.abstractproperty`` is deprecated; stacking ``property`` on top of
    # ``abstractmethod`` is the equivalent spelling.
    @property
    @abc.abstractmethod
    def _frozen_array_types(self) -> tuple[type, ...]:
        """
        Returns valid frozen array types for the array context.
        """

    # {{{ compilation

    def transform_dag(self, dag: pytato.DictOfNamedArrays
                      ) -> pytato.DictOfNamedArrays:
        """
        Returns a transformed version of *dag*. Sub-classes are supposed to
        override this method to implement context-specific transformations on
        *dag* (most likely to perform domain-specific optimizations). Every
        :mod:`pytato` DAG that is compiled to a GPU-kernel is
        passed through this routine.

        :arg dag: An instance of :class:`pytato.DictOfNamedArrays`
        :returns: A transformed version of *dag*.
        """
        return dag

    def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
        """
        Return *t_unit* unchanged after emitting an
        :class:`~arraycontext.context.UntransformedCodeWarning`.
        Subclasses are expected to override this with real transformations.
        """
        from warnings import warn
        # Each fragment ends in a space so that the adjacent literals
        # concatenate into readable sentences.
        warn("Using the base "
             f"{type(self).__name__}.transform_loopy_program "
             "to transform a translation unit. "
             "This is a no-op and will result in unoptimized C code for "
             "the requested optimization, all in a single statement. "
             "This will work, but is unlikely to be performant. "
             f"Instead, subclass {type(self).__name__} and implement "
             "the specific transform logic required to transform the program "
             "for your package or application. Check higher-level packages "
             "(e.g. meshmode), which may already have subclasses you may want "
             "to build on.",
             UntransformedCodeWarning, stacklevel=2)

        return t_unit

    @abc.abstractmethod
    def einsum(self, spec, *args, arg_names=None, tagged=()):
        pass

    # }}}

    # {{{ properties

    @property
    def permits_inplace_modification(self):
        return False

    @property
    def supports_nonscalar_broadcasting(self):
        return True

    @property
    def permits_advanced_indexing(self):
        return True

    def get_target(self):
        # No specific code generation target by default; subclasses may
        # override.
        return None

    # }}}
# }}}
# {{{ PytatoPyOpenCLArrayContext
class PytatoPyOpenCLArrayContext(ArrayContext):
class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
"""
A :class:`ArrayContext` that uses :mod:`pytato` data types to represent
An :class:`ArrayContext` that uses :mod:`pytato` data types to represent
the arrays targeting OpenCL for offloading operations.
.. attribute:: queue
......@@ -62,125 +250,687 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
to use the default allocator.
.. automethod:: __init__
.. automethod:: transform_dag
.. automethod:: compile
"""
def __init__(
self, queue: cl.CommandQueue, allocator=None, *,
use_memory_pool: bool | None = None,
compile_trace_callback: Callable[[Any, str, Any], None] | None = None,
# do not use: only for testing
_force_svm_arg_limit: int | None = None,
) -> None:
"""
:arg compile_trace_callback: A function of three arguments
*(what, stage, ir)*, where *what* identifies the object
being compiled, *stage* is a string describing the compilation
pass, and *ir* is an object containing the intermediate
representation. This interface should be considered
unstable.
"""
if allocator is not None and use_memory_pool is not None:
raise TypeError("may not specify both allocator and use_memory_pool")
self.using_svm = None
if allocator is None:
from pyopencl.characterize import has_coarse_grain_buffer_svm
has_svm = has_coarse_grain_buffer_svm(queue.device)
if has_svm:
self.using_svm = True
from pyopencl.tools import SVMAllocator
allocator = SVMAllocator(queue.context, queue=queue)
if use_memory_pool:
from pyopencl.tools import SVMPool
allocator = SVMPool(allocator)
else:
self.using_svm = False
from pyopencl.tools import ImmediateAllocator
allocator = ImmediateAllocator(queue)
if use_memory_pool:
from pyopencl.tools import MemoryPool
allocator = MemoryPool(allocator)
else:
# Check whether the passed allocator allocates SVM
try:
from pyopencl import SVMPointer
mem = allocator(4)
if isinstance(mem, SVMPointer):
self.using_svm = True
else:
self.using_svm = False
except ImportError:
self.using_svm = False
def __init__(self, queue, allocator=None):
import pyopencl.array as cla
import pytato as pt
super().__init__()
super().__init__(compile_trace_callback=compile_trace_callback)
self.queue = queue
self.allocator = allocator
self.array_types = (pt.Array, )
self.array_types = (pt.Array, cla.Array)
# unused, but necessary to keep the context alive
self.context = self.queue.context
def _get_fake_numpy_namespace(self):
from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace
return PytatoFakeNumpyNamespace(self)
self._force_svm_arg_limit = _force_svm_arg_limit
@property
def _frozen_array_types(self) -> tuple[type, ...]:
    """Array types this context accepts as frozen: pyopencl arrays."""
    from pyopencl.array import Array as CLArray
    return (CLArray,)
def _rec_map_container(
        self, func: Callable[[Array], Array], array: ArrayOrContainer,
        allowed_types: tuple[type, ...] | None = None, *,
        default_scalar: ScalarLike | None = None,
        strict: bool = False) -> ArrayOrContainer:
    """Apply *func* to every leaf of *array* that has an allowed type.

    :arg allowed_types: leaf types handed to *func*; defaults to
        ``(pt.Array, TaggableCLArray)``.
    :arg default_scalar: if not *None*, scalar leaves are replaced by this
        value converted to the scalar's own dtype; otherwise scalars pass
        through unchanged.
    :arg strict: NOTE(review): accepted but not used in this method.
    :raises TypeError: for a leaf that is neither allowed nor a scalar.
    """
    import pytato as pt

    import arraycontext.impl.pyopencl.taggable_cl_array as tga

    if allowed_types is None:
        allowed_types = (pt.Array, tga.TaggableCLArray)

    def _wrapper(ary):
        if isinstance(ary, allowed_types):
            return func(ary)

        if not np.isscalar(ary):
            raise TypeError(
                f"{func.__qualname__} invoked with "
                f"an unsupported array type: got '{type(ary).__name__}', "
                f"but expected one of {allowed_types}")

        if default_scalar is None:
            return ary

        # Keep the dtype of the incoming scalar, but take the new value.
        scalar_type = np.array(ary).dtype.type
        return scalar_type(default_scalar)

    return rec_map_array_container(_wrapper, array)
# {{{ ArrayContext interface
def clone(self):
    """Return a new context of the same type built from this context's
    queue and allocator."""
    # NOTE(review): only ``queue`` and ``allocator`` are forwarded; the
    # keyword-only constructor options (``compile_trace_callback``,
    # ``_force_svm_arg_limit``) are not carried over -- confirm that this
    # is intended.
    return type(self)(self.queue, self.allocator)
def from_numpy(self, array):
    """Convert the :class:`numpy.ndarray` leaves of *array* to pytato
    data wrappers.

    Each leaf is passed through ``tga.to_device`` with this context's
    queue and allocator and wrapped with ``pt.make_data_wrapper``; the
    result is associated with this array context.
    """
    import pytato as pt

    import arraycontext.impl.pyopencl.taggable_cl_array as tga

    def _from_numpy(ary):
        on_device = tga.to_device(self.queue, ary, allocator=self.allocator)
        return pt.make_data_wrapper(on_device)

    converted = self._rec_map_container(
        _from_numpy, array, (np.ndarray,), strict=True)
    return with_array_context(converted, actx=self)
def to_numpy(self, array):
    """Freeze *array* and fetch every leaf with ``.get``.

    The returned container is detached from any array context.
    """
    def _to_numpy(ary):
        return ary.get(queue=self.queue)

    fetched = self._rec_map_container(_to_numpy, self.freeze(array))
    return with_array_context(fetched, actx=None)
@memoize_method
def get_target(self):
import pyopencl as cl
import pyopencl.characterize as cl_char
dev = self.queue.device
if (
self._force_svm_arg_limit is not None
or (
self.using_svm and dev.type & cl.device_type.GPU
and cl_char.has_coarse_grain_buffer_svm(dev))):
if dev.max_parameter_size == 4352:
# Nvidia devices and PTXAS declare a limit of 4352 bytes,
# which is incorrect. The CUDA documentation at
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#function-parameters
# mentions a limit of 4KB, which is also incorrect.
# As far as I can tell, the actual limit is around 4080
# bytes, at least on a K40. Reducing the limit further
# in order to be on the safe side.
# Note that the naming convention isn't super consistent
# for Nvidia GPUs, so that we only use the maximum
# parameter size to determine if it is an Nvidia GPU.
limit = 4096-200
from warnings import warn
warn("Running on an Nvidia GPU, reducing the argument "
f"size limit from 4352 to {limit}.", stacklevel=1)
else:
limit = dev.max_parameter_size
if self._force_svm_arg_limit is not None:
limit = self._force_svm_arg_limit
logger.info(
"limiting argument buffer size for %s to %d bytes",
dev, limit)
def empty(self, shape, dtype):
raise ValueError("PytatoPyOpenCLArrayContext does not support empty")
from arraycontext.impl.pytato.utils import (
ArgSizeLimitingPytatoLoopyPyOpenCLTarget,
)
return ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit)
else:
return super().get_target()
def zeros(self, shape, dtype):
def freeze(self, array):
if np.isscalar(array):
return array
import pyopencl.array as cla
import pytato as pt
return pt.zeros(shape, dtype)
def from_numpy(self, np_array: np.ndarray):
from arraycontext.container.traversal import rec_keyed_map_array_container
from arraycontext.impl.pyopencl.taggable_cl_array import (
TaggableCLArray,
to_tagged_cl_array,
)
from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
from arraycontext.impl.pytato.utils import (
_normalize_pt_expr,
get_cl_axes_from_pt_axes,
)
array_as_dict: dict[str, cla.Array | TaggableCLArray | pt.Array] = {}
key_to_frozen_subary: dict[str, TaggableCLArray] = {}
key_to_pt_arrays: dict[str, pt.Array] = {}
def _record_leaf_ary_in_dict(
key: tuple[Any, ...],
ary: cla.Array | TaggableCLArray | pt.Array) -> None:
key_str = "_ary" + _ary_container_key_stringifier(key)
array_as_dict[key_str] = ary
rec_keyed_map_array_container(_record_leaf_ary_in_dict, array)
# {{{ remove any non pytato arrays from array_as_dict
for key, subary in array_as_dict.items():
if isinstance(subary, TaggableCLArray):
key_to_frozen_subary[key] = subary.with_queue(None)
elif isinstance(subary, self._frozen_array_types):
from warnings import warn
warn(f"Invoking {type(self).__name__}.freeze with"
f" {type(subary).__name__} will be unsupported in 2023. Use"
" `to_tagged_cl_array` to convert instances to TaggableCLArray.",
DeprecationWarning, stacklevel=2)
key_to_frozen_subary[key] = (
to_tagged_cl_array(subary.with_queue(None)))
elif isinstance(subary, pt.DataWrapper):
# trivial freeze.
key_to_frozen_subary[key] = to_tagged_cl_array(
subary.data,
axes=get_cl_axes_from_pt_axes(subary.axes),
tags=subary.tags)
elif isinstance(subary, pt.Array):
# Don't be tempted to take shortcuts here, e.g. for empty
# arrays, as this will inhibit metadata propagation that
# may happen in transform_dag below. See
# https://github.com/inducer/arraycontext/pull/167#issuecomment-1151877480
key_to_pt_arrays[key] = subary
else:
raise TypeError(
f"{type(self).__name__}.freeze invoked with an unsupported "
f"array type: got '{type(subary).__name__}', but expected one "
f"of {self.array_types}")
# }}}
def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
key_str = "_ary" + _ary_container_key_stringifier(key)
return key_to_frozen_subary[key_str]
if not key_to_pt_arrays:
# all cl arrays => no need to perform any codegen
return with_array_context(
rec_keyed_map_array_container(_to_frozen, array),
actx=None)
pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
key_to_pt_arrays)
normalized_expr, bound_arguments = _normalize_pt_expr(
pt_dict_of_named_arrays)
try:
pt_prg = self._freeze_prg_cache[normalized_expr]
except KeyError:
try:
transformed_dag, function_name = (
self._dag_transform_cache[normalized_expr])
except KeyError:
transformed_dag = self.transform_dag(normalized_expr)
from pytato.tags import PrefixNamed
name_hint_tags = []
for subary in key_to_pt_arrays.values():
name_hint_tags.extend(subary.tags_of_type(PrefixNamed))
from pytools import common_prefix
name_hint = common_prefix([nh.prefix for nh in name_hint_tags])
# All name_hint_tags shared at least some common prefix.
function_name = f"frozen_{name_hint}" if name_hint else "frozen_result"
self._dag_transform_cache[normalized_expr] = (
transformed_dag, function_name)
from arraycontext.loopy import _DEFAULT_LOOPY_OPTIONS
opts = _DEFAULT_LOOPY_OPTIONS
assert opts.return_dict
pt_prg = pt.generate_loopy(transformed_dag,
options=opts,
cl_device=self.queue.device,
function_name=function_name,
target=self.get_target()
).bind_to_context(self.context)
pt_prg = pt_prg.with_transformed_translation_unit(
self.transform_loopy_program)
self._freeze_prg_cache[normalized_expr] = pt_prg
else:
transformed_dag, function_name = (
self._dag_transform_cache[normalized_expr])
assert len(pt_prg.bound_arguments) == 0
evt, out_dict = pt_prg(self.queue,
allocator=self.allocator,
**bound_arguments)
evt.wait()
assert len(set(out_dict) & set(key_to_frozen_subary)) == 0
key_to_frozen_subary = {
**key_to_frozen_subary,
**{k: to_tagged_cl_array(
v.with_queue(None),
axes=get_cl_axes_from_pt_axes(transformed_dag[k].expr.axes),
tags=transformed_dag[k].expr.tags)
for k, v in out_dict.items()}
}
return with_array_context(
rec_keyed_map_array_container(_to_frozen, array),
actx=None)
def thaw(self, array):
import pytato as pt
import pyopencl.array as cla
cl_array = cla.to_device(self.queue, np_array)
return pt.make_data_wrapper(cl_array)
def to_numpy(self, array):
cl_array = self.freeze(array)
return cl_array.get(queue=self.queue)
import arraycontext.impl.pyopencl.taggable_cl_array as tga
from .utils import get_pt_axes_from_cl_axes
def _thaw(ary):
return pt.make_data_wrapper(ary.with_queue(self.queue),
axes=get_pt_axes_from_cl_axes(ary.axes),
tags=ary.tags)
return with_array_context(
self._rec_map_container(_thaw, array, (tga.TaggableCLArray,)),
actx=self)
def freeze_thaw(self, array):
    """
    Short-circuiting variant of a freeze-then-thaw round trip.

    If every array leaf of *array* is a :class:`pytato.DataWrapper` or a
    ``TaggableCLArray``, the leaves are passed through as-is and the result
    is attached to this array context. Otherwise the work is delegated to
    the parent class' ``freeze_thaw``.
    """
    import pytato as pt

    import arraycontext.impl.pyopencl.taggable_cl_array as tga

    passthrough_types = (pt.DataWrapper, tga.TaggableCLArray)

    # NOTE: the helper keeps its name on purpose. The JAX sibling of
    # `_rec_map_container` derives its error text from `func.__name__`;
    # presumably this context's implementation does too -- confirm before
    # renaming.
    def _ft(ary):
        if not isinstance(ary, passthrough_types):
            # Signal "needs a real freeze/thaw" to the handler below.
            raise _NotOnlyDataWrappers()
        return ary

    try:
        mapped = self._rec_map_container(_ft, array)
        return with_array_context(mapped, actx=self)
    except _NotOnlyDataWrappers:
        return super().freeze_thaw(array)
def tag(self, tags: ToTagSetConvertible, array):
    """
    Return *array* with *tags* attached to each of its array leaves.

    *tags* is normalized via ``_preprocess_array_tags`` once per leaf before
    being handed to the leaf's ``tagged`` method.
    """
    def _tag(ary):
        attach = ary.tagged
        return attach(_preprocess_array_tags(tags))

    tagged_result = self._rec_map_container(_tag, array)
    return tagged_result
def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
    """
    Return *array* with *tags* attached to axis *iaxis* of each array leaf.

    Each leaf is transformed by calling its ``with_tagged_axis(iaxis, tags)``.
    """
    def _tag_axis(ary):
        return ary.with_tagged_axis(iaxis, tags)

    axis_tagged = self._rec_map_container(_tag_axis, array)
    return axis_tagged
# }}}
# {{{ compilation
def call_loopy(self, program, **kwargs):
    """
    Build a lazy call to the loopy *program* with *kwargs* as its arguments.

    :arg program: a loopy translation unit; its default entrypoint is the
        kernel that gets called.
    :arg kwargs: keyword arguments for the kernel. Each value must be a
        :class:`pytato.Array`, a scalar (one of
        ``pytato.scalar_expr.SCALAR_CLASSES``) or a ``TaggableCLArray``;
        the latter is thawed first.
    :returns: the result of :func:`pytato.loopy.call_loopy`.
    :raises ValueError: if an argument is of an unsupported type.
    """
    import pyopencl.array as cla
    import pytato as pt
    from pytato.loopy import call_loopy
    from pytato.scalar_expr import SCALAR_CLASSES

    from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray

    entrypoint = program.default_entrypoint.name

    # thaw frozen arrays
    # NOTE(review): if TaggableCLArray derives from pyopencl's Array, this
    # pass already thaws it and the TaggableCLArray branch below is only a
    # safety net -- confirm against the class hierarchy.
    kwargs = {kw: (self.thaw(arg) if isinstance(arg, cla.Array) else arg)
              for kw, arg in kwargs.items()}

    # {{{ preprocess args

    processed_kwargs = {}

    # Sorted so that arguments are processed in a deterministic order.
    for kw, arg in sorted(kwargs.items()):
        if isinstance(arg, (pt.Array, *SCALAR_CLASSES)):
            pass
        elif isinstance(arg, TaggableCLArray):
            arg = self.thaw(arg)
        else:
            # The original message lacked a space between "or" and
            # "'TaggableCLArray'".
            raise ValueError(f"call_loopy argument '{kw}' expected to be an"
                             " instance of 'pytato.Array', 'Number' or"
                             f" 'TaggableCLArray', got '{type(arg)}'")

        processed_kwargs[kw] = arg

    # }}}

    return call_loopy(program, processed_kwargs, entrypoint)
def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
    """
    Wrap *f* in a ``LazilyPyOpenCLCompilingFunctionCaller`` bound to this
    array context and return the wrapper.
    """
    from .compile import LazilyPyOpenCLCompilingFunctionCaller as caller_cls

    return caller_cls(self, f)
def transform_dag(self, dag: pytato.DictOfNamedArrays
                  ) -> pytato.DictOfNamedArrays:
    """
    Return *dag* after applying
    ``pytato.transform.materialize_with_mpms`` to it.
    """
    import pytato as pt

    return pt.transform.materialize_with_mpms(dag)
def einsum(self, spec, *args, arg_names=None, tagged=()):
    """
    Lazily evaluate the Einstein summation *spec* over *args*.

    :arg spec: the subscript specification, passed on to
        :func:`pytato.einsum`.
    :arg args: the operands. Each may be a ``TaggableCLArray`` (thawed
        first), another frozen array type of this context (deprecated;
        converted via ``to_tagged_cl_array`` and thawed), or a
        :class:`pytato.Array` (used as-is).
    :arg arg_names: optional names, one per operand. A non-*None* name is
        attached as a ``NameHint`` tag unless the operand is a placeholder
        or already carries a ``NameHint``.
    :arg tagged: tags attached to the resulting array.
    :raises TypeError: if an operand is of an unsupported type.
    :raises ValueError: if *arg_names* and *args* differ in length
        (``zip(..., strict=True)``).
    """
    # A leftover `return call_loopy(program, kwargs, entrypoint)` used to sit
    # here. It referenced names that do not exist in this method and made
    # everything below unreachable, so it has been removed.
    import pytato as pt

    import arraycontext.impl.pyopencl.taggable_cl_array as tga

    if arg_names is None:
        arg_names = (None,) * len(args)

    def preprocess_arg(name, arg):
        if isinstance(arg, tga.TaggableCLArray):
            ary = self.thaw(arg)
        elif isinstance(arg, self._frozen_array_types):
            from warnings import warn
            warn(f"Invoking {type(self).__name__}.einsum with"
                f" {type(arg).__name__} will be unsupported in 2023. Use"
                " `to_tagged_cl_array` to convert instances to TaggableCLArray.",
                DeprecationWarning, stacklevel=2)
            ary = self.thaw(tga.to_tagged_cl_array(arg))
        elif isinstance(arg, pt.Array):
            ary = arg
        else:
            raise TypeError(
                f"{type(self).__name__}.einsum invoked with an unsupported "
                f"array type: got '{type(arg).__name__}', but expected one "
                f"of {self.array_types}")

        if name is not None:  # noqa: SIM102
            # Tagging Placeholders with naming-related tags is pointless:
            # They already have names. It's also counterproductive, as
            # multiple placeholders with the same name that are not
            # also the same object are not allowed, and this would produce
            # a different Placeholder object of the same name.
            if (not isinstance(ary, pt.Placeholder)
                    and not ary.tags_of_type(NameHint)):
                ary = ary.tagged(NameHint(name))

        return ary

    return pt.einsum(spec, *[
        preprocess_arg(name, arg)
        for name, arg in zip(arg_names, args, strict=True)
        ]).tagged(_preprocess_array_tags(tagged))
def clone(self):
    """
    Return a new context of the same concrete type, constructed from this
    context's ``queue`` and ``allocator``.
    """
    ctx_cls = type(self)
    return ctx_cls(self.queue, self.allocator)
# }}}
# }}}
# {{{ PytatoJAXArrayContext
class PytatoJAXArrayContext(_BasePytatoArrayContext):
"""
An arraycontext that uses :mod:`pytato` to represent the thawed state of
the arrays and compiles the expressions using
:class:`pytato.target.python.JAXPythonTarget`.
"""
def __init__(self,
*,
compile_trace_callback: Callable[[Any, str, Any], None] | None = None,
) -> None:
"""
:arg compile_trace_callback: A function of three arguments
*(what, stage, ir)*, where *what* identifies the object
being compiled, *stage* is a string describing the compilation
pass, and *ir* is an object containing the intermediate
representation. This interface should be considered
unstable.
"""
import jax.numpy as jnp
import pytato as pt
super().__init__(compile_trace_callback=compile_trace_callback)
self.array_types = (pt.Array, jnp.ndarray)
@property
def _frozen_array_types(self) -> tuple[type, ...]:
import jax.numpy as jnp
return (jnp.ndarray, )
def _rec_map_container(
self, func: Callable[[Array], Array], array: ArrayOrContainer,
allowed_types: tuple[type, ...] | None = None, *,
default_scalar: ScalarLike | None = None,
strict: bool = False) -> ArrayOrContainer:
if allowed_types is None:
allowed_types = self.array_types
def _wrapper(ary):
if isinstance(ary, allowed_types):
return func(ary)
elif np.isscalar(ary):
if default_scalar is None:
return ary
else:
return np.array(ary).dtype.type(default_scalar)
else:
raise TypeError(
f"{type(self).__name__}.{func.__name__[1:]} invoked with "
f"an unsupported array type: got '{type(ary).__name__}', "
f"but expected one of {allowed_types}")
return rec_map_array_container(_wrapper, array)
# {{{ ArrayContext interface
def from_numpy(self, array):
import jax
import pytato as pt
def _from_numpy(ary):
return pt.make_data_wrapper(jax.device_put(ary))
return with_array_context(
self._rec_map_container(_from_numpy, array, (np.ndarray,)),
actx=self)
def to_numpy(self, array):
import jax
def _to_numpy(ary):
return jax.device_get(ary)
return with_array_context(
self._rec_map_container(_to_numpy, self.freeze(array)),
actx=None)
def freeze(self, array):
# TODO: This should store a cache of pytato DAG -> build pyopencl
# program instead of re-compiling the DAG for every freeze.
if np.isscalar(array):
return array
import jax.numpy as jnp
import pytato as pt
import pyopencl.array as cla
if isinstance(array, cla.Array):
return array.with_queue(None)
if not isinstance(array, pt.Array):
raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with "
f"non-pytato array of type '{type(array)}'")
from arraycontext.container.traversal import rec_keyed_map_array_container
from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
pt_prg = pt.generate_loopy(array, cl_device=self.queue.device)
pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
array_as_dict: dict[str, jnp.ndarray | pt.Array] = {}
key_to_frozen_subary: dict[str, jnp.ndarray] = {}
key_to_pt_arrays: dict[str, pt.Array] = {}
evt, (cl_array,) = pt_prg(self.queue)
evt.wait()
def _record_leaf_ary_in_dict(key: tuple[Any, ...],
ary: jnp.ndarray | pt.Array) -> None:
key_str = "_ary" + _ary_container_key_stringifier(key)
array_as_dict[key_str] = ary
return cl_array.with_queue(None)
rec_keyed_map_array_container(_record_leaf_ary_in_dict, array)
# {{{ remove any non pytato arrays from array_as_dict
for key, subary in array_as_dict.items():
if isinstance(subary, jnp.ndarray):
key_to_frozen_subary[key] = subary.block_until_ready()
elif isinstance(subary, pt.DataWrapper):
# trivial freeze.
key_to_frozen_subary[key] = subary.data.block_until_ready()
elif isinstance(subary, pt.Array):
key_to_pt_arrays[key] = subary
else:
raise TypeError(
f"{type(self).__name__}.freeze invoked with an unsupported "
f"array type: got '{type(subary).__name__}', but expected one "
f"of {self.array_types}")
# }}}
def _to_frozen(key: tuple[Any, ...], ary) -> jnp.ndarray:
key_str = "_ary" + _ary_container_key_stringifier(key)
return key_to_frozen_subary[key_str]
if not key_to_pt_arrays:
# all cl arrays => no need to perform any codegen
return with_array_context(
rec_keyed_map_array_container(_to_frozen, array),
actx=None)
pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(key_to_pt_arrays)
transformed_dag = self.transform_dag(pt_dict_of_named_arrays)
pt_prg = pt.generate_jax(transformed_dag, jit=True)
out_dict = pt_prg()
assert len(set(out_dict) & set(key_to_frozen_subary)) == 0
key_to_frozen_subary = {
**key_to_frozen_subary,
**{k: v.block_until_ready()
for k, v in out_dict.items()}
}
return with_array_context(
rec_keyed_map_array_container(_to_frozen, array),
actx=None)
def thaw(self, array):
import pytato as pt
import pyopencl.array as cla
if not isinstance(array, cla.Array):
raise TypeError("PytatoPyOpenCLArrayContext.thaw expects CL arrays, got "
f"{type(array)}")
return pt.make_data_wrapper(array.with_queue(self.queue))
def _thaw(ary):
return pt.make_data_wrapper(ary)
# }}}
return with_array_context(
self._rec_map_container(_thaw, array, self._frozen_array_types),
actx=self)
def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
from arraycontext.impl.pytato.compile import LazilyCompilingFunctionCaller
return LazilyCompilingFunctionCaller(self, f)
from .compile import LazilyJAXCompilingFunctionCaller
return LazilyJAXCompilingFunctionCaller(self, f)
def tag(self, tags: ToTagSetConvertible, array):
def _tag(ary):
import jax.numpy as jnp
if isinstance(ary, jnp.ndarray):
return ary
else:
return ary.tagged(_preprocess_array_tags(tags))
def transform_loopy_program(self, t_unit):
raise ValueError("PytatoPyOpenCLArrayContext does not implement "
"transform_loopy_program. Sub-classes are supposed "
"to implement it.")
return self._rec_map_container(_tag, array)
def tag(self, tags: Union[Sequence[Tag], Tag], array):
return array.tagged(tags)
def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
def _tag_axis(ary):
import jax.numpy as jnp
if isinstance(ary, jnp.ndarray):
return ary
else:
return ary.with_tagged_axis(iaxis, tags)
def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array):
# TODO
from warnings import warn
warn("tagging PytatoPyOpenCLArrayContext's array axes: not yet implemented",
stacklevel=2)
return array
return self._rec_map_container(_tag_axis, array)
# }}}
# {{{ compilation
def call_loopy(self, program, **kwargs):
raise NotImplementedError(
"Calling loopy on JAX arrays is not supported. Maybe rewrite"
" the loopy kernel as numpy-flavored array operations using"
" ArrayContext.np.")
def einsum(self, spec, *args, arg_names=None, tagged=()):
import pyopencl.array as cla
import pytato as pt
if arg_names is not None:
from warnings import warn
warn("'arg_names' don't bear any significance in "
"PytatoPyOpenCLArrayContext.", stacklevel=2)
def preprocess_arg(arg):
if isinstance(arg, cla.Array):
return self.thaw(arg)
if arg_names is None:
arg_names = (None,) * len(args)
def preprocess_arg(name, arg):
import jax.numpy as jnp
if isinstance(arg, jnp.ndarray):
ary = self.thaw(arg)
elif isinstance(arg, pt.Array):
ary = arg
else:
assert isinstance(arg, pt.Array)
return arg
raise TypeError(
f"{type(self).__name__}.einsum invoked with an unsupported "
f"array type: got '{type(arg).__name__}', but expected one "
f"of {self.array_types}")
if name is not None: # noqa: SIM102
# Tagging Placeholders with naming-related tags is pointless:
# They already have names. It's also counterproductive, as
# multiple placeholders with the same name that are not
# also the same object are not allowed, and this would produce
# a different Placeholder object of the same name.
if (not isinstance(ary, pt.Placeholder)
and not ary.tags_of_type(NameHint)):
ary = ary.tagged(NameHint(name))
return ary
return pt.einsum(spec, *[
preprocess_arg(name, arg)
for name, arg in zip(arg_names, args, strict=True)
]).tagged(_preprocess_array_tags(tagged))
return pt.einsum(spec, *(preprocess_arg(arg) for arg in args))
def clone(self):
return type(self)()
@property
def permits_inplace_modification(self):
return False
# }}}
# }}}
# vim: foldmethod=marker
"""
.. currentmodule:: arraycontext.impl.pytato.compile
.. autoclass:: LazilyCompilingFunctionCaller
.. autoclass:: BaseLazilyCompilingFunctionCaller
.. autoclass:: LazilyPyOpenCLCompilingFunctionCaller
.. autoclass:: LazilyJAXCompilingFunctionCaller
.. autoclass:: CompiledFunction
.. autoclass:: FromArrayContextCompile
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
"""
......@@ -27,25 +32,60 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from arraycontext.container import ArrayContainer
from arraycontext import PytatoPyOpenCLArrayContext
from arraycontext.container.traversal import (rec_keyed_map_array_container,
is_array_container)
import abc
import itertools
import logging
from collections.abc import Callable, Hashable, Mapping
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from typing import Any, Callable, Tuple, Dict, Mapping
from dataclasses import dataclass, field
from pyrsistent import pmap, PMap
from immutabledict import immutabledict
import pyopencl.array as cla
import pytato as pt
from pytools import ProcessLogger, to_identifier
from pytools.tag import Tag
from arraycontext.container import ArrayContainer, is_array_container_type
from arraycontext.container.traversal import rec_keyed_map_array_container
from arraycontext.context import ArrayT
from arraycontext.impl.pytato import (
PytatoJAXArrayContext,
PytatoPyOpenCLArrayContext,
_BasePytatoArrayContext,
)
logger = logging.getLogger(__name__)
def _prg_id_to_kernel_name(f: Any) -> str:
if callable(f):
name = getattr(f, "__name__", "anonymous")
if not name.isidentifier():
return "actx_compiled_" + to_identifier(name)
else:
return name
else:
return to_identifier(str(f))
class FromArrayContextCompile(Tag):
    """
    Marker tag placed on the entrypoint kernel of each translation unit
    produced by :meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`.

    Transformation code can test for this tag to choose a strategy that is
    specific to kernels originating from
    :meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`.
    """
# {{{ helper classes: AbstractInputDescriptor
class AbstractInputDescriptor:
"""
Used internally in :class:`LazilyCompilingFunctionCaller` to characterize
Used internally in :class:`BaseLazilyCompilingFunctionCaller` to characterize
an input.
"""
def __eq__(self, other):
......@@ -63,14 +103,16 @@ class ScalarInputDescriptor(AbstractInputDescriptor):
@dataclass(frozen=True, eq=True)
class LeafArrayDescriptor(AbstractInputDescriptor):
dtype: np.dtype
shape: Tuple[int, ...]
shape: pt.array.ShapeType
# }}}
def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
# {{{ utilities
def _ary_container_key_stringifier(keys: tuple[Any, ...]) -> str:
"""
Helper for :meth:`LazilyCompilingFunctionCaller.__call__`. Stringifies an
Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an
array-container's component's key. Goals of this routine:
* No two different keys should have the same stringification
......@@ -78,7 +120,7 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
* (informal) Shorter identifiers are preferred
"""
def _rec_str(key: Any) -> str:
if isinstance(key, (str, int)):
if isinstance(key, str | int):
return str(key)
elif isinstance(key, tuple):
# t in '_actx_t': stands for tuple
......@@ -90,70 +132,129 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
return "_".join(_rec_str(key) for key in keys)
def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...]
) -> "Tuple[PMap[Tuple[Any, ...],\
Any],\
PMap[Tuple[Any, ...],\
AbstractInputDescriptor]\
]":
def _get_arg_id_to_arg_and_arg_id_to_descr(args: tuple[Any, ...],
kwargs: Mapping[str, Any]
) -> \
tuple[Mapping[tuple[Hashable, ...], Any],
Mapping[tuple[Hashable, ...], AbstractInputDescriptor]]:
"""
Helper for :meth:`LazilyCompilingFunctionCaller.__call__`. Extracts
Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Extracts
mappings from argument id to argument values and from argument id to
:class:`AbstractInputDescriptor`. See
:attr:`CompiledFunction.input_id_to_name_in_program` for argument-id's
representation.
"""
arg_id_to_arg: Dict[Tuple[Any, ...], Any] = {}
arg_id_to_descr: Dict[Tuple[Any, ...], AbstractInputDescriptor] = {}
arg_id_to_arg: dict[tuple[Hashable, ...], Any] = {}
arg_id_to_descr: dict[tuple[Hashable, ...], AbstractInputDescriptor] = {}
for iarg, arg in enumerate(args):
for kw, arg in itertools.chain(enumerate(args),
kwargs.items()):
if np.isscalar(arg):
arg_id = (iarg,)
arg_id = (kw,)
arg_id_to_arg[arg_id] = arg
arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(arg))
elif is_array_container(arg):
arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(type(arg)))
elif is_array_container_type(arg.__class__):
def id_collector(keys, ary):
arg_id = (iarg,) + keys
arg_id = (kw, *keys) # noqa: B023
arg_id_to_arg[arg_id] = ary
arg_id_to_descr[arg_id] = LeafArrayDescriptor(np.dtype(ary.dtype),
ary.shape)
arg_id_to_descr[arg_id] = LeafArrayDescriptor(
np.dtype(ary.dtype), ary.shape)
return ary
rec_keyed_map_array_container(id_collector, arg)
elif isinstance(arg, pt.Array):
arg_id = (kw,)
arg_id_to_arg[arg_id] = arg
arg_id_to_descr[arg_id] = LeafArrayDescriptor(np.dtype(arg.dtype),
arg.shape)
else:
raise ValueError("Argument to a compiled operator should be"
" either a scalar or an array container. Got"
" either a scalar, pt.Array or an array container. Got"
f" '{arg}'.")
return pmap(arg_id_to_arg), pmap(arg_id_to_descr)
return immutabledict(arg_id_to_arg), immutabledict(arg_id_to_descr)
def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext):
    """
    Prepare *ary* for being turned into a
    :class:`pytato.array.Placeholder` in
    :meth:`LazilyCompilingFunctionCaller.__call__`.

    - A :class:`pytato.Array` is run through *actx*'s
      :meth:`PytatoPyOpenCLArrayContext.transform_dag` so that metadata
      inference can take place.
    - A ``TaggableCLArray`` is returned as-is.
    - A plain :class:`pyopencl.array.Array` is converted to a
      ``TaggableCLArray`` with no axes and no tags, with a deprecation
      warning.

    :raises NotImplementedError: for any other type of *ary*.
    """
    import pyopencl.array as cla

    from arraycontext.impl.pyopencl.taggable_cl_array import (
        TaggableCLArray,
        to_tagged_cl_array,
    )

    if isinstance(ary, pt.Array):
        # Transform the DAG to give metadata inference a chance to do its job
        single_output_dag = pt.make_dict_of_named_arrays({"_actx_out": ary})
        transformed = actx.transform_dag(single_output_dag)
        return transformed["_actx_out"].expr

    if isinstance(ary, TaggableCLArray):
        return ary

    if not isinstance(ary, cla.Array):
        raise NotImplementedError(type(ary))

    from warnings import warn
    warn("Passing pyopencl.array.Array to a compiled callable"
         " is deprecated and will stop working in 2023."
         " Use `to_tagged_cl_array` to convert the array to"
         " TaggableCLArray", DeprecationWarning, stacklevel=2)

    return to_tagged_cl_array(ary, axes=None, tags=frozenset())
def _get_f_placeholder_args(arg, iarg, arg_id_to_name):
def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx):
"""
Helper for :class:`LazilyCompilingFunctionCaller.__call__`. Returns the
Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. Returns the
placeholder version of an argument to
:attr:`LazilyCompilingFunctionCaller.f`.
:attr:`BaseLazilyCompilingFunctionCaller.f`.
"""
if np.isscalar(arg):
name = arg_id_to_name[(iarg,)]
return pt.make_placeholder(name, (), np.dtype(arg))
elif is_array_container(arg):
from pytato.tags import ForceValueArgTag
name = arg_id_to_name[kw,]
return pt.make_placeholder(name, (), np.dtype(type(arg)),
tags=frozenset({ForceValueArgTag()}))
elif isinstance(arg, pt.Array):
name = arg_id_to_name[kw,]
# Transform the DAG to give metadata inference a chance to do its job
arg = _to_input_for_compiled(arg, actx)
return pt.make_placeholder(name, arg.shape, arg.dtype,
axes=arg.axes,
tags=arg.tags)
elif is_array_container_type(arg.__class__):
def _rec_to_placeholder(keys, ary):
name = arg_id_to_name[(iarg,) + keys]
return pt.make_placeholder(name, ary.shape, ary.dtype)
return rec_keyed_map_array_container(_rec_to_placeholder,
arg)
index = (kw, *keys)
name = arg_id_to_name[index]
# Transform the DAG to give metadata inference a chance to do its job
ary = _to_input_for_compiled(ary, actx)
return pt.make_placeholder(name,
ary.shape,
ary.dtype,
axes=ary.axes,
tags=ary.tags)
return rec_keyed_map_array_container(_rec_to_placeholder, arg)
else:
raise NotImplementedError(type(arg))
# }}}
# {{{ BaseLazilyCompilingFunctionCaller
@dataclass
class LazilyCompilingFunctionCaller:
class BaseLazilyCompilingFunctionCaller:
"""
Records a side-effect-free callable
:attr:`LazilyCompilingFunctionCaller.f` that can be specialized for the
input types with which :meth:`LazilyCompilingFunctionCaller.__call__` is
invoked.
Records a side-effect-free callable :attr:`f` that can be specialized for
the input types with which :meth:`__call__` is invoked.
.. attribute:: f
......@@ -162,23 +263,69 @@ class LazilyCompilingFunctionCaller:
.. automethod:: __call__
"""
actx: PytatoPyOpenCLArrayContext
actx: _BasePytatoArrayContext
f: Callable[..., Any]
program_cache: Dict["PMap[Tuple[Any, ...], AbstractInputDescriptor]",
"CompiledFunction"] = field(default_factory=lambda: {})
program_cache: dict[Mapping[tuple[Hashable, ...], AbstractInputDescriptor],
CompiledFunction] = field(default_factory=lambda: {})
def __call__(self, *args: Any) -> Any:
# {{{ abstract interface
def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
raise NotImplementedError
@property
def compiled_function_returning_array_container_class(
        self) -> type[CompiledFunction]:
    """
    Abstract hook: the :class:`CompiledFunction` subclass instantiated when
    the traced function yields a :class:`pytato.DictOfNamedArrays`.
    """
    raise NotImplementedError
@property
def compiled_function_returning_array_class(self) -> type[CompiledFunction]:
    """
    Abstract hook: the :class:`CompiledFunction` subclass instantiated when
    the traced function yields a single :class:`pytato.Array`.
    """
    raise NotImplementedError
# }}}
def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays,
        input_id_to_name_in_program, output_id_to_name_in_program,
        output_template):
    """
    Turn a traced result into a :class:`CompiledFunction`.

    A single :class:`pytato.Array` is wrapped into a one-entry
    :class:`pytato.DictOfNamedArrays` under the name ``"_pt_out"`` and
    compiled into an instance of
    :attr:`compiled_function_returning_array_class`. A
    :class:`pytato.DictOfNamedArrays` is compiled into an instance of
    :attr:`compiled_function_returning_array_container_class`.

    :raises NotImplementedError: for any other input type.
    """
    traced = ary_or_dict_of_named_arrays

    if isinstance(traced, pt.Array):
        output_id = "_pt_out"
        wrapped = pt.make_dict_of_named_arrays({output_id: traced})

        program, tags_by_name, axes_by_name = (
            self._dag_to_transformed_pytato_prg(wrapped, prg_id=self.f))

        make_compiled = self.compiled_function_returning_array_class
        return make_compiled(
            self.actx, program,
            input_id_to_name_in_program=input_id_to_name_in_program,
            output_tags=tags_by_name[output_id],
            output_axes=axes_by_name[output_id],
            output_name=output_id)

    if isinstance(traced, pt.DictOfNamedArrays):
        program, tags_by_name, axes_by_name = (
            self._dag_to_transformed_pytato_prg(traced, prg_id=self.f))

        make_compiled = self.compiled_function_returning_array_container_class
        return make_compiled(
            self.actx, program,
            input_id_to_name_in_program=input_id_to_name_in_program,
            output_id_to_name_in_program=output_id_to_name_in_program,
            name_in_program_to_tags=tags_by_name,
            name_in_program_to_axes=axes_by_name,
            output_template=output_template)

    raise NotImplementedError(type(traced))
def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""
Returns the result of :attr:`~LazilyCompilingFunctionCaller.f`'s
Returns the result of :attr:`~BaseLazilyCompilingFunctionCaller.f`'s
function application on *args*.
Before applying :attr:`~LazilyCompilingFunctionCaller.f`, it is compiled
Before applying :attr:`~BaseLazilyCompilingFunctionCaller.f`, it is compiled
to a :mod:`pytato` DAG that would apply
:attr:`~LazilyCompilingFunctionCaller.f` with *args* in a lazy-sense.
:attr:`~BaseLazilyCompilingFunctionCaller.f` with *args* in a lazy-sense.
The intermediary pytato DAG for *args* is memoized in *self*.
"""
from pytato.target.loopy import BoundPyOpenCLProgram
arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(args)
arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(
args, kwargs)
try:
compiled_f = self.program_cache[arg_id_to_descr]
......@@ -188,70 +335,288 @@ class LazilyCompilingFunctionCaller:
return compiled_f(arg_id_to_arg)
dict_of_named_arrays = {}
# output_naming_map: result id to name of the named array in the
# generated pytato DAG.
output_naming_map = {}
# input_naming_map: argument id to placeholder name in the generated
# pytato DAG.
input_naming_map = {
output_id_to_name_in_program = {}
input_id_to_name_in_program = {
arg_id: f"_actx_in_{_ary_container_key_stringifier(arg_id)}"
for arg_id in arg_id_to_arg}
outputs = self.f(*[_get_f_placeholder_args(arg, iarg, input_naming_map)
for iarg, arg in enumerate(args)])
output_template = self.f(
*[_get_f_placeholder_args(arg, iarg,
input_id_to_name_in_program, self.actx)
for iarg, arg in enumerate(args)],
**{kw: _get_f_placeholder_args(arg, kw,
input_id_to_name_in_program,
self.actx)
for kw, arg in kwargs.items()})
if not is_array_container(outputs):
self.actx._compile_trace_callback(self.f, "post_trace", output_template)
if (not (is_array_container_type(output_template.__class__)
or isinstance(output_template, pt.Array))):
# TODO: We could possibly just short-circuit this interface if the
# returned type is a scalar. Not sure if it's worth it though.
raise NotImplementedError(
f"Function '{self.f.__name__}' to be compiled "
"did not return an array container, but an instance of "
f"'{outputs.__class__}' instead.")
"did not return an array container or pt.Array,"
f" but an instance of '{output_template.__class__}' instead.")
def _as_dict_of_named_arrays(keys, ary):
name = "_pt_out_" + "_".join(str(key)
for key in keys)
output_naming_map[keys] = name
name = "_pt_out_" + _ary_container_key_stringifier(keys)
output_id_to_name_in_program[keys] = name
dict_of_named_arrays[name] = ary
return ary
rec_keyed_map_array_container(_as_dict_of_named_arrays,
outputs)
output_template)
pytato_program = pt.generate_loopy(dict_of_named_arrays,
options={"return_dict": True},
cl_device=self.actx.queue.device)
assert isinstance(pytato_program, BoundPyOpenCLProgram)
compiled_func = self._dag_to_compiled_func(
pt.make_dict_of_named_arrays(dict_of_named_arrays),
input_id_to_name_in_program=input_id_to_name_in_program,
output_id_to_name_in_program=output_id_to_name_in_program,
output_template=output_template)
pytato_program = (pytato_program
.with_transformed_program(self
.actx
.transform_loopy_program))
self.program_cache[arg_id_to_descr] = compiled_func
return compiled_func(arg_id_to_arg)
self.program_cache[arg_id_to_descr] = CompiledFunction(
self.actx, pytato_program,
input_naming_map, output_naming_map,
output_template=outputs)
# }}}
return self.program_cache[arg_id_to_descr](arg_id_to_arg)
# {{{ LazilyPyOpenCLCompilingFunctionCaller
@dataclass
class CompiledFunction:
class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller):
    """
    A :class:`BaseLazilyCompilingFunctionCaller` that lowers traced DAGs to
    :mod:`loopy` programs bound to the array context's OpenCL context.
    """
    actx: PytatoPyOpenCLArrayContext

    @property
    def compiled_function_returning_array_container_class(
            self) -> type[CompiledFunction]:
        return CompiledPyOpenCLFunctionReturningArrayContainer

    @property
    def compiled_function_returning_array_class(self) -> type[CompiledFunction]:
        return CompiledPyOpenCLFunctionReturningArray

    def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
        """
        Transform *dict_of_named_arrays*, generate a loopy program from it
        and transform that program.

        The array context's ``_compile_trace_callback`` is invoked before
        and after every stage, and once more with the stage ``"final"``.

        :arg prg_id: identifies the program in log messages, trace callbacks
            and the generated kernel name; defaults to ``self.f``.
        :returns: a tuple ``(pytato_program, name_in_program_to_tags,
            name_in_program_to_axes)``, where both mappings are read off the
            outputs of the transformed DAG.
        """
        if prg_id is None:
            prg_id = self.f

        from pytato.target.loopy import BoundPyOpenCLExecutable

        def _trace(stage, ir):
            self.actx._compile_trace_callback(prg_id, stage, ir)

        def _tag_entrypoint(t_unit):
            # Mark the entrypoint so transforms can recognize compiled kernels.
            return t_unit.with_kernel(
                t_unit.default_entrypoint.tagged(FromArrayContextCompile()))

        # {{{ DAG transformation

        _trace("pre_transform_dag", dict_of_named_arrays)

        with ProcessLogger(logger, f"transform_dag for '{prg_id}'"):
            transformed_dag = self.actx.transform_dag(dict_of_named_arrays)

        _trace("post_transform_dag", transformed_dag)

        # }}}

        name_in_program_to_tags = {}
        name_in_program_to_axes = {}
        for name, out in transformed_dag._data.items():
            name_in_program_to_tags[name] = out.tags
            name_in_program_to_axes[name] = out.axes

        # {{{ code generation

        _trace("pre_generate_loopy", transformed_dag)

        with ProcessLogger(logger, f"generate_loopy for '{prg_id}'"):
            from arraycontext.loopy import _DEFAULT_LOOPY_OPTIONS
            opts = _DEFAULT_LOOPY_OPTIONS
            assert opts.return_dict

            pytato_program = pt.generate_loopy(
                    transformed_dag,
                    options=opts,
                    function_name=_prg_id_to_kernel_name(prg_id),
                    target=self.actx.get_target(),
                    ).bind_to_context(self.actx.context)  # pylint: disable=no-member
            assert isinstance(pytato_program, BoundPyOpenCLExecutable)

        _trace("post_generate_loopy", pytato_program)

        # }}}

        # {{{ loopy transformation

        _trace("pre_transform_loopy_program", pytato_program)

        with ProcessLogger(logger, f"transform_loopy_program for '{prg_id}'"):
            pytato_program = pytato_program.with_transformed_translation_unit(
                    _tag_entrypoint)
            pytato_program = pytato_program.with_transformed_translation_unit(
                    self.actx.transform_loopy_program)

        _trace("post_transform_loopy_program", pytato_program)

        # }}}

        _trace("final", pytato_program)

        return pytato_program, name_in_program_to_tags, name_in_program_to_axes
# }}}
# {{{ preserve back compat
class LazilyCompilingFunctionCaller(LazilyPyOpenCLCompilingFunctionCaller):
    """
    Deprecated alias of :class:`LazilyPyOpenCLCompilingFunctionCaller`, kept
    for backward compatibility. Instantiating it emits a
    :class:`DeprecationWarning`.
    """
    def __new__(cls, *args, **kwargs):
        from warnings import warn

        message = ("LazilyCompilingFunctionCaller has been renamed to"
                   " LazilyPyOpenCLCompilingFunctionCaller. This will be"
                   " an error in 2023.")
        warn(message, DeprecationWarning, stacklevel=2)
        return super().__new__(cls)

    def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays):
        """Deprecated spelling of ``_dag_to_transformed_pytato_prg``."""
        from warnings import warn

        message = ("_dag_to_transformed_loopy_prg has been renamed to"
                   " _dag_to_transformed_pytato_prg. This will be"
                   " an error in 2023.")
        warn(message, DeprecationWarning, stacklevel=2)
        return super()._dag_to_transformed_pytato_prg(dict_of_named_arrays)
# }}}
# {{{ LazilyJAXCompilingFunctionCaller
class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller):
    """
    A :class:`BaseLazilyCompilingFunctionCaller` that lowers traced DAGs via
    :func:`pytato.generate_jax` with ``jit=True``.
    """
    @property
    def compiled_function_returning_array_container_class(
            self) -> type[CompiledFunction]:
        return CompiledJAXFunctionReturningArrayContainer

    @property
    def compiled_function_returning_array_class(self) -> type[CompiledFunction]:
        return CompiledJAXFunctionReturningArray

    def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
        """
        Transform *dict_of_named_arrays* and generate a JAX program from it.

        The array context's ``_compile_trace_callback`` is invoked before
        and after both stages.

        :arg prg_id: identifies the program in log messages, trace callbacks
            and the generated function name; defaults to ``self.f``.
        :returns: a tuple ``(pytato_program, name_in_program_to_tags,
            name_in_program_to_axes)``, where both mappings are read off the
            outputs of the transformed DAG.
        """
        if prg_id is None:
            prg_id = self.f

        def _trace(stage, ir):
            self.actx._compile_trace_callback(prg_id, stage, ir)

        _trace("pre_transform_dag", dict_of_named_arrays)

        with ProcessLogger(logger, f"transform_dag for '{prg_id}'"):
            transformed_dag = self.actx.transform_dag(dict_of_named_arrays)

        _trace("post_transform_dag", transformed_dag)

        name_in_program_to_tags = {}
        name_in_program_to_axes = {}
        for name, out in transformed_dag._data.items():
            name_in_program_to_tags[name] = out.tags
            name_in_program_to_axes[name] = out.axes

        _trace("pre_generate_jax", transformed_dag)

        with ProcessLogger(logger, f"generate_jax for '{prg_id}'"):
            pytato_program = pt.generate_jax(
                    transformed_dag,
                    jit=True,
                    function_name=_prg_id_to_kernel_name(prg_id))

        _trace("post_generate_jax", pytato_program)

        return pytato_program, name_in_program_to_tags, name_in_program_to_axes
def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
    """Convert the arguments of a compiled function into program inputs.

    :arg actx: the array context the compiled function belongs to.
    :arg input_id_to_name_in_program: mapping from input id to the name
        under which the input is passed to the program.
    :arg arg_id_to_arg: mapping from input id to the passed argument.
    :returns: a :class:`dict` mapping names in the program to the
        converted arguments.
    :raises NotImplementedError: for scalars with an unsupported array
        context type, or for arguments of an unsupported type.
    """
    kwargs_by_program_name = {}

    for arg_id, value in arg_id_to_arg.items():
        if np.isscalar(value):
            # For PytatoPyOpenCLArrayContext, scalar kernel args are passed
            # as lp.ValueArgs, so they are handed over unchanged.
            if not isinstance(actx, PytatoPyOpenCLArrayContext):
                if isinstance(actx, PytatoJAXArrayContext):
                    import jax
                    value = jax.device_put(value)
                else:
                    raise NotImplementedError(type(actx))
        elif isinstance(value, pt.array.DataWrapper):
            # A data wrapper: use the data it wraps.
            value = value.data
        elif not isinstance(value, actx._frozen_array_types):
            # Frozen arrays are used as-is; anything else must be an
            # unevaluated array expression.
            if isinstance(value, pt.Array):
                from warnings import warn
                warn(f"Argument array '{arg_id}' to a compiled function is "
                     "unevaluated. Evaluating just-in-time, at "
                     "considerable expense. This is deprecated and will stop "
                     "working in 2023. To avoid this warning, force evaluation "
                     "of all arguments via freeze/thaw.",
                     DeprecationWarning, stacklevel=4)

                value = actx.freeze(value)
            else:
                raise NotImplementedError(type(value))

        kwargs_by_program_name[input_id_to_name_in_program[arg_id]] = value

    return kwargs_by_program_name
def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
    """Deprecated alias of :func:`_args_to_device_buffers`.

    Emits a :class:`DeprecationWarning` and forwards all arguments
    unchanged, returning the result of :func:`_args_to_device_buffers`.
    """
    from warnings import warn

    # The message names this function exactly ("_args_to_cl_buffers"), so
    # that searching for the reported name finds the offending call sites.
    warn("_args_to_cl_buffers has been renamed to"
         " _args_to_device_buffers. This will be"
         " an error in 2023.", DeprecationWarning, stacklevel=2)

    return _args_to_device_buffers(actx, input_id_to_name_in_program,
                                   arg_id_to_arg)
# }}}
# {{{ compiled function
class CompiledFunction(abc.ABC):
    """
    A callable which captures the :class:`pytato.target.BoundProgram` resulting
    from calling :attr:`~BaseLazilyCompilingFunctionCaller.f` with a given set of
    input types, and generating :mod:`loopy` IR from it.

    .. attribute:: pytato_program

    .. attribute:: input_id_to_name_in_program

        A mapping from input id to the placeholder name in
        :attr:`CompiledFunction.pytato_program`. Input id is represented as the
        position of :attr:`~BaseLazilyCompilingFunctionCaller.f`'s argument augmented
        with the leaf array's key if the argument is an array container.

    .. automethod:: __call__
    """

    @abc.abstractmethod
    def __call__(self, arg_id_to_arg) -> Any:
        """
        :arg arg_id_to_arg: Mapping from input id to the passed argument. See
            :attr:`CompiledFunction.input_id_to_name_in_program` for input id's
            representation.
        """
        pass
# }}}
# {{{ compiled pyopencl function
@dataclass(frozen=True)
class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction):
"""
.. attribute:: output_id_to_name_in_program
A mapping from output id to the name of
......@@ -265,55 +630,144 @@ class CompiledFunction:
An instance of :class:`arraycontext.ArrayContainer` that is the return
type of the callable.
"""
actx: PytatoPyOpenCLArrayContext
pytato_program: pt.target.BoundProgram
input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
output_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
name_in_program_to_tags: Mapping[str, frozenset[Tag]]
name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]]
output_template: ArrayContainer
def __call__(self, arg_id_to_arg) -> ArrayContainer:
"""
:arg arg_id_to_arg: Mapping from input id to the passed argument. See
:attr:`CompiledFunction.input_id_to_name_in_program` for input id's
representation.
"""
from arraycontext.container.traversal import rec_keyed_map_array_container
from .utils import get_cl_axes_from_pt_axes
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
input_kwargs_to_loopy = {}
input_kwargs_for_loopy = _args_to_device_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
# {{{ preprocess args to get arguments (CL buffers) to be fed to the
# loopy program
evt, out_dict = self.pytato_program(queue=self.actx.queue,
allocator=self.actx.allocator,
**input_kwargs_for_loopy)
for arg_id, arg in arg_id_to_arg.items():
if np.isscalar(arg):
arg = cla.to_device(self.actx.queue, np.array(arg))
elif isinstance(arg, pt.array.DataWrapper):
# got a Datwwrapper => simply gets its data
arg = arg.data
elif isinstance(arg, cla.Array):
# got a frozen array => do nothing
pass
elif isinstance(arg, pt.Array):
# got an array expression => evaluate it
arg = self.actx.freeze(arg).with_queue(self.actx.queue)
else:
raise NotImplementedError(type(arg))
# FIXME Kernels (for now) allocate tons of memory in temporaries. If we
# race too far ahead with enqueuing, there is a distinct risk of
# running out of memory. This mitigates that risk a bit, for now.
evt.wait()
def to_output_template(keys, _):
name_in_program = self.output_id_to_name_in_program[keys]
return self.actx.thaw(to_tagged_cl_array(
out_dict[name_in_program],
axes=get_cl_axes_from_pt_axes(
self.name_in_program_to_axes[name_in_program]),
tags=self.name_in_program_to_tags[name_in_program]))
return rec_keyed_map_array_container(to_output_template,
self.output_template)
@dataclass(frozen=True)
class CompiledPyOpenCLFunctionReturningArray(CompiledFunction):
"""
.. attribute:: output_name_in_program
Name of the output array in the program.
"""
actx: PytatoPyOpenCLArrayContext
pytato_program: pt.target.BoundProgram
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
output_tags: frozenset[Tag]
output_axes: tuple[pt.Axis, ...]
output_name: str
def __call__(self, arg_id_to_arg) -> ArrayContainer:
from .utils import get_cl_axes_from_pt_axes
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
input_kwargs_to_loopy[self.input_id_to_name_in_program[arg_id]] = arg
input_kwargs_for_loopy = _args_to_device_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
evt, out_dict = self.pytato_program(queue=self.actx.queue,
allocator=self.actx.allocator,
**input_kwargs_to_loopy)
**input_kwargs_for_loopy)
# FIXME Kernels (for now) allocate tons of memory in temporaries. If we
# race too far ahead with enqueuing, there is a distinct risk of
# running out of memory. This mitigates that risk a bit, for now.
evt.wait()
# }}}
return self.actx.thaw(to_tagged_cl_array(out_dict[self.output_name],
axes=get_cl_axes_from_pt_axes(
self.output_axes),
tags=self.output_tags))
# }}}
# {{{ compiled jax function
@dataclass(frozen=True)
class CompiledJAXFunctionReturningArrayContainer(CompiledFunction):
    """
    .. attribute:: output_id_to_name_in_program

        A mapping from output id to the name of
        :class:`pytato.array.NamedArray` in
        :attr:`CompiledFunction.pytato_program`. Output id is represented by
        the key of a leaf array in the array container
        :attr:`CompiledFunction.output_template`.

    .. attribute:: output_template

       An instance of :class:`arraycontext.ArrayContainer` that is the return
       type of the callable.
    """
    actx: PytatoJAXArrayContext
    pytato_program: pt.target.BoundProgram
    input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
    output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
    name_in_program_to_tags: Mapping[str, frozenset[Tag]]
    name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]]
    output_template: ArrayContainer

    def __call__(self, arg_id_to_arg) -> ArrayContainer:
        """
        :arg arg_id_to_arg: Mapping from input id to the passed argument. See
            :attr:`CompiledFunction.input_id_to_name_in_program` for input id's
            representation.
        :returns: a container shaped like :attr:`output_template` whose
            leaves are the thawed program outputs.
        """
        input_kwargs_for_loopy = _args_to_device_buffers(
                self.actx, self.input_id_to_name_in_program, arg_id_to_arg)

        out_dict = self.pytato_program(**input_kwargs_for_loopy)

        def to_output_template(keys, _):
            # Exactly one return: wait for the output to be computed before
            # thawing it. A stale, non-blocking ``return`` used to precede
            # this statement and made it unreachable.
            return self.actx.thaw(
                out_dict[self.output_id_to_name_in_program[keys]]
                .block_until_ready()
            )

        return rec_keyed_map_array_container(to_output_template,
                                             self.output_template)
@dataclass(frozen=True)
class CompiledJAXFunctionReturningArray(CompiledFunction):
    """
    .. attribute:: output_name

        Name of the output array in the program.
    """
    actx: PytatoJAXArrayContext
    pytato_program: pt.target.BoundProgram
    input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
    output_tags: frozenset[Tag]
    output_axes: tuple[pt.Axis, ...]
    output_name: str

    def __call__(self, arg_id_to_arg) -> ArrayContainer:
        """
        :arg arg_id_to_arg: Mapping from input id to the passed argument. See
            :attr:`CompiledFunction.input_id_to_name_in_program` for input id's
            representation.
        :returns: the thawed output array named :attr:`output_name`.
        """
        input_kwargs_for_loopy = _args_to_device_buffers(
                self.actx, self.input_id_to_name_in_program, arg_id_to_arg)

        # The JAX program's result is used directly as the mapping of
        # outputs, exactly as CompiledJAXFunctionReturningArrayContainer
        # does; it is not an ``(event, outputs)`` pair as for PyOpenCL, so
        # it must not be tuple-unpacked.
        out_dict = self.pytato_program(**input_kwargs_for_loopy)

        return self.actx.thaw(out_dict[self.output_name])
# }}}
# vim: foldmethod=marker
from __future__ import annotations
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""
......@@ -22,23 +25,29 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from functools import partial, reduce
from typing import Any, cast
import numpy as np
from arraycontext.fake_numpy import (
BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace,
)
from arraycontext.container.traversal import (
rec_multimap_array_container, rec_map_array_container,
rec_map_reduce_array_container,
)
import pytato as pt
from arraycontext.container import NotAnArrayContainerError, serialize_container
from arraycontext.container.traversal import (
rec_map_array_container,
rec_map_reduce_array_container,
rec_multimap_array_container,
)
from arraycontext.context import Array, ArrayOrContainer
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
    """The ``linalg`` namespace handed out by
    ``PytatoFakeNumpyNamespace._get_fake_numpy_linalg_namespace``.
    """
    # Everything is implemented in the base class for now.
    pass
class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
"""
A :mod:`numpy` mimic for :class:`PytatoPyOpenCLArrayContext`.
......@@ -47,91 +56,74 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
:mod:`pytato` does not define any memory layout for its arrays. See
:ref:`Pytato docs <pytato:memory-layout>` for more on this.
"""
_pt_unary_funcs = frozenset({
"sin", "cos", "tan", "arcsin", "arccos", "arctan",
"sinh", "cosh", "tanh", "exp", "log", "log10",
"sqrt", "abs", "isnan", "real", "imag", "conj",
"logical_not",
})
_pt_multi_ary_funcs = frozenset({
"arctan2", "equal", "greater", "greater_equal", "less", "less_equal",
"not_equal", "minimum", "maximum", "where", "logical_and", "logical_or",
})
def _get_fake_numpy_linalg_namespace(self):
return PytatoFakeNumpyLinalgNamespace(self._array_context)
def __getattr__(self, name):
pt_funcs = ["abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan",
"sinh", "cosh", "tanh", "exp", "log", "log10", "isnan",
"sqrt", "exp"]
if name in pt_funcs:
if name in self._pt_unary_funcs:
from functools import partial
return partial(rec_map_array_container, getattr(pt, name))
return super().__getattr__(name)
def reshape(self, a, newshape):
return rec_multimap_array_container(pt.reshape, a, newshape)
def transpose(self, a, axes=None):
return rec_multimap_array_container(pt.transpose, a, axes)
def concatenate(self, arrays, axis=0):
return rec_multimap_array_container(pt.concatenate, arrays, axis)
def ones_like(self, ary):
def _ones_like(subary):
return pt.ones(subary.shape, subary.dtype)
return self._new_like(ary, _ones_like)
def maximum(self, x, y):
return rec_multimap_array_container(pt.maximum, x, y)
def minimum(self, x, y):
return rec_multimap_array_container(pt.minimum, x, y)
def where(self, criterion, then, else_):
return rec_multimap_array_container(pt.where, criterion, then, else_)
def sum(self, a, dtype=None):
def _pt_sum(ary):
if dtype not in [ary.dtype, None]:
raise NotImplementedError
if name in self._pt_multi_ary_funcs:
from functools import partial
return partial(rec_multimap_array_container, getattr(pt, name))
return pt.sum(ary)
return super().__getattr__(name)
return rec_map_reduce_array_container(sum, _pt_sum, a)
# NOTE: the order of these follows the order in numpy docs
# NOTE: when adding a function here, also add it to `array_context.rst` docs!
def min(self, a):
return rec_map_reduce_array_container(
partial(reduce, pt.minimum), pt.amin, a)
# {{{ array creation routines
def max(self, a):
return rec_map_reduce_array_container(
partial(reduce, pt.maximum), pt.amax, a)
def zeros(self, shape, dtype):
return pt.zeros(shape, dtype)
def stack(self, arrays, axis=0):
return rec_multimap_array_container(
lambda *args: pt.stack(arrays=args, axis=axis),
*arrays)
def zeros_like(self, ary):
def _zeros_like(array):
return self._array_context.zeros(
array.shape, array.dtype).copy(axes=array.axes, tags=array.tags)
# {{{ relational operators
return self._array_context._rec_map_container(
_zeros_like, ary, default_scalar=0)
def equal(self, x, y):
return rec_multimap_array_container(pt.equal, x, y)
def ones_like(self, ary):
return self.full_like(ary, 1)
def not_equal(self, x, y):
return rec_multimap_array_container(pt.not_equal, x, y)
def full_like(self, ary, fill_value):
def _full_like(subary):
return pt.full(subary.shape, fill_value, subary.dtype).copy(
axes=subary.axes, tags=subary.tags)
def greater(self, x, y):
return rec_multimap_array_container(pt.greater, x, y)
return self._array_context._rec_map_container(
_full_like, ary, default_scalar=fill_value)
def greater_equal(self, x, y):
return rec_multimap_array_container(pt.greater_equal, x, y)
def arange(self, *args: Any, **kwargs: Any):
return pt.arange(*args, **kwargs)
def less(self, x, y):
return rec_multimap_array_container(pt.less, x, y)
def full(self, shape, fill_value, dtype=None):
return pt.full(shape, fill_value, dtype)
def less_equal(self, x, y):
return rec_multimap_array_container(pt.less_equal, x, y)
# }}}
def conj(self, x):
return rec_multimap_array_container(pt.conj, x)
# {{{ array manipulation routines
def arctan2(self, y, x):
return rec_multimap_array_container(pt.arctan2, y, x)
def reshape(self, a, newshape, order="C"):
    """Reshape every leaf array of *a* to *newshape*.

    :arg a: an array or array container.
    :arg newshape: the target shape, forwarded to :func:`pytato.reshape`.
    :arg order: the index order, forwarded to :func:`pytato.reshape`.
    :returns: an object structured like *a* with each leaf reshaped.
    """
    # The mapped function must act on the leaf it is handed (*ary*), not on
    # the outer argument *a*, which may be a whole array container.
    return rec_map_array_container(
        lambda ary: pt.reshape(ary, newshape, order=order),
        a)
def ravel(self, a, order="C"):
"""
......@@ -155,4 +147,99 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
return rec_map_array_container(_rec_ravel, a)
def transpose(self, a, axes=None):
    """Apply :func:`pytato.transpose` across *a* and *axes* via
    :func:`rec_multimap_array_container`.
    """
    transposed = rec_multimap_array_container(pt.transpose, a, axes)
    return transposed
def broadcast_to(self, array, shape):
    """Broadcast every leaf array of *array* to *shape* using
    :func:`pytato.broadcast_to`.
    """
    def _broadcast_leaf(subary):
        return pt.broadcast_to(subary, shape=shape)

    return rec_map_array_container(_broadcast_leaf, array)
def concatenate(self, arrays, axis=0):
    """Apply :func:`pytato.concatenate` across *arrays* and *axis* via
    :func:`rec_multimap_array_container`.
    """
    concatenated = rec_multimap_array_container(pt.concatenate, arrays, axis)
    return concatenated
def stack(self, arrays, axis=0):
    """Stack corresponding leaf arrays of *arrays* along *axis* using
    :func:`pytato.stack`.
    """
    def _stack_leaves(*args):
        return pt.stack(arrays=args, axis=axis)

    return rec_multimap_array_container(_stack_leaves, *arrays)
# }}}
# {{{ logic functions
def all(self, a):
    """Reduce *a* with :func:`pytato.all` per leaf array, combining the
    per-leaf results with :func:`pytato.logical_and`.
    """
    combine_leaves = partial(reduce, pt.logical_and)

    def _leaf_all(subary):
        return pt.all(subary)

    return rec_map_reduce_array_container(combine_leaves, _leaf_all, a)
def any(self, a):
    """Reduce *a* with :func:`pytato.any` per leaf array, combining the
    per-leaf results with :func:`pytato.logical_or`.
    """
    combine_leaves = partial(reduce, pt.logical_or)

    def _leaf_any(subary):
        return pt.any(subary)

    return rec_map_reduce_array_container(combine_leaves, _leaf_any, a)
def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
    """Compare *a* and *b* recursively and return the result as an array.

    Objects of differing type, leaf arrays of differing shape, and
    containers with a differing number of serialized entries compare as
    the "false" array. Leaf arrays of equal shape are compared with
    :func:`pytato.equal` and reduced with :func:`pytato.all`.
    """
    actx = self._array_context

    # NOTE: not all backends support `bool` properly, so use `int8` instead
    true_ary = actx.from_numpy(np.int8(True))
    false_ary = actx.from_numpy(np.int8(False))

    def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> pt.Array:
        # Exact type match is required; subclasses are not considered equal.
        if type(x) is not type(y):
            return false_ary

        try:
            serialized_x = serialize_container(x)
            serialized_y = serialize_container(y)
        except NotAnArrayContainerError:
            # Not containers, so both are expected to be leaf arrays.
            assert isinstance(x, pt.Array)
            assert isinstance(y, pt.Array)

            if x.shape != y.shape:
                return false_ary
            else:
                return pt.all(cast(pt.Array, pt.equal(x, y)))
        else:
            if len(serialized_x) != len(serialized_y):
                return false_ary
            # Entries are compared pairwise, in serialization order; each
            # pair contributes a key comparison combined with the recursive
            # comparison of the values.
            # NOTE(review): the Python `and` below depends on the truthiness
            # of `true_ary`/`false_ary` -- confirm that is intended.
            return reduce(
                pt.logical_and,
                [(true_ary if kx_i == ky_i else false_ary)
                    and rec_equal(x_i, y_i)
                    for (kx_i, x_i), (ky_i, y_i)
                    in zip(serialized_x, serialized_y, strict=True)],
                true_ary)

    return cast(Array, rec_equal(a, b))
# }}}
# {{{ mathematical functions
def sum(self, a, axis=None, dtype=None):
    """Sum every leaf array of *a* with :func:`pytato.sum` along *axis* and
    add up the per-leaf results.

    :raises NotImplementedError: if *dtype* is neither *None* nor the dtype
        of the leaf array being summed.
    """
    def _sum_leaf(subary):
        if dtype not in (subary.dtype, None):
            raise NotImplementedError

        return pt.sum(subary, axis=axis)

    # `sum` here is the builtin: it adds up the per-leaf results.
    return rec_map_reduce_array_container(sum, _sum_leaf, a)
def amax(self, a, axis=None):
    """Take :func:`pytato.amax` of every leaf array of *a* along *axis*
    and combine the per-leaf results with :func:`pytato.maximum`.
    """
    leaf_max = partial(pt.amax, axis=axis)
    combine_leaves = partial(reduce, pt.maximum)
    return rec_map_reduce_array_container(combine_leaves, leaf_max, a)

max = amax
def amin(self, a, axis=None):
    """Take :func:`pytato.amin` of every leaf array of *a* along *axis*
    and combine the per-leaf results with :func:`pytato.minimum`.
    """
    leaf_min = partial(pt.amin, axis=axis)
    combine_leaves = partial(reduce, pt.minimum)
    return rec_map_reduce_array_container(combine_leaves, leaf_min, a)

min = amin
def absolute(self, a):
    """Alias for ``abs``: return ``self.abs(a)``."""
    result = self.abs(a)
    return result
def vdot(self, a: Array, b: Array):
    """Apply :func:`pytato.vdot` to *a* and *b* via
    :func:`rec_multimap_array_container`.
    """
    product = rec_multimap_array_container(pt.vdot, a, b)
    return product
# }}}
from __future__ import annotations
__doc__ = """
.. autofunction:: transfer_from_numpy
.. autofunction:: transfer_to_numpy
"""
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast
from pytato.array import (
AbstractResultWithNamedArrays,
Array,
Axis as PtAxis,
DataWrapper,
DictOfNamedArrays,
Placeholder,
SizeParam,
make_placeholder,
)
from pytato.target.loopy import LoopyPyOpenCLTarget
from pytato.transform import ArrayOrNames, CopyMapper
from pytools import UniqueNameGenerator, memoize_method
from arraycontext import ArrayContext
from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis
if TYPE_CHECKING:
import loopy as lp
class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
    """
    Helper mapper for :func:`normalize_pt_expr`. Every
    :class:`pytato.DataWrapper` is replaced with a deterministic copy of
    :class:`Placeholder`.

    After the mapper has been applied, ``bound_arguments`` maps each
    generated placeholder name to the data of the data wrapper it replaced.
    """
    def __init__(self) -> None:
        super().__init__()
        # placeholder name -> data of the replaced data wrapper
        self.bound_arguments: dict[str, Any] = {}
        self.vng = UniqueNameGenerator()
        # names of the named data wrappers encountered so far
        self.seen_inputs: set[str] = set()

    def map_data_wrapper(self, expr: DataWrapper) -> Array:
        """Replace *expr* with a placeholder and record its data.

        :raises ValueError: if a second data wrapper with the same
            (non-*None*) name is encountered.
        """
        if expr.name is not None:
            if expr.name in self.seen_inputs:
                # Note the space after "name": the two literals are
                # concatenated and would otherwise run into the value.
                raise ValueError("Got multiple inputs with the name "
                                 f"{expr.name} => Illegal.")
            self.seen_inputs.add(expr.name)

        # Normalizing names so that more arrays can have the same normalized DAG.
        from pytato.codegen import _generate_name_for_temp
        name = _generate_name_for_temp(expr, self.vng, "_actx_dw")
        self.bound_arguments[name] = expr.data
        return make_placeholder(
            name=name,
            # Array-valued shape components are mapped as well.
            shape=tuple(cast(Array, self.rec(s)) if isinstance(s, Array) else s
                        for s in expr.shape),
            dtype=expr.dtype,
            axes=expr.axes,
            tags=expr.tags)

    def map_size_param(self, expr: SizeParam) -> Array:
        """Size parameters are not supported by this mapper."""
        raise NotImplementedError

    def map_placeholder(self, expr: Placeholder) -> Array:
        """Reject placeholders in the input expression."""
        raise ValueError("Placeholders cannot appear in"
                         " DatawrapperToBoundPlaceholderMapper.")
def _normalize_pt_expr(
        expr: DictOfNamedArrays
        ) -> tuple[Array | AbstractResultWithNamedArrays, Mapping[str, Any]]:
    """
    Normalize *expr* and return ``(normalized_expr, bound_arguments)``.

    In *normalized_expr*, every :class:`pytato.DataWrapper` of *expr* has
    been replaced by a :class:`Placeholder` with a deterministically
    generated name. *bound_arguments* maps those placeholder names to the
    data of the wrappers they stand for.

    Naming the placeholders deterministically permits more effective
    caching of equivalent graphs.
    """
    mapper = _DatawrapperToBoundPlaceholderMapper()
    result = mapper(expr)
    assert isinstance(result, AbstractResultWithNamedArrays)

    bound_arguments = mapper.bound_arguments
    return result, bound_arguments
def get_pt_axes_from_cl_axes(axes: tuple[ClAxis, ...]) -> tuple[PtAxis, ...]:
    """Build one :mod:`pytato` axis per entry of *axes*, carrying over its
    tags.
    """
    pt_axes = [PtAxis(cl_axis.tags) for cl_axis in axes]
    return tuple(pt_axes)
def get_cl_axes_from_pt_axes(axes: tuple[PtAxis, ...]) -> tuple[ClAxis, ...]:
    """Build one taggable-CL-array axis per entry of *axes*, carrying over
    its tags.
    """
    cl_axes = [ClAxis(pt_axis.tags) for pt_axis in axes]
    return tuple(cl_axes)
# {{{ arg-size-limiting loopy target
class ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget):
    """A :class:`LoopyPyOpenCLTarget` whose :mod:`loopy` target is created
    with a ``limit_arg_size_nbytes`` setting.
    """

    def __init__(self, limit_arg_size_nbytes: int) -> None:
        super().__init__()
        self.limit_arg_size_nbytes = limit_arg_size_nbytes

    @memoize_method
    def get_loopy_target(self) -> lp.PyOpenCLTarget:
        """Return the (memoized) loopy target carrying the size limit."""
        import loopy

        nbytes_limit = self.limit_arg_size_nbytes
        return loopy.PyOpenCLTarget(limit_arg_size_nbytes=nbytes_limit)
# }}}
# {{{ Transfer mappers
class TransferFromNumpyMapper(CopyMapper):
    """A mapper that moves the :class:`numpy.ndarray` data held by
    :class:`~pytato.array.DataWrapper` instances to the device, using
    :meth:`~arraycontext.ArrayContext.from_numpy`.
    """
    def __init__(self, actx: ArrayContext) -> None:
        super().__init__()
        self.actx = actx

    def map_data_wrapper(self, expr: DataWrapper) -> Array:
        """Return a copy of *expr* that wraps the transferred data.

        :raises ValueError: if the wrapped data is not a
            :class:`numpy.ndarray`.
        """
        import numpy as np

        host_data = expr.data
        if not isinstance(host_data, np.ndarray):
            raise ValueError("TransferFromNumpyMapper: tried to transfer data that "
                             "is already on the device")

        # Ideally, this code should just do
        # return self.actx.from_numpy(expr.data).tagged(expr.tags),
        # but there seems to be no way to transfer the non_equality_tags in that case.
        transferred = self.actx.from_numpy(host_data)
        assert isinstance(transferred, DataWrapper)

        # Rebuild the wrapper around the transferred data so that shape,
        # axes and both kinds of tags are taken from the original.
        # https://github.com/pylint-dev/pylint/issues/3893
        # pylint: disable=unexpected-keyword-arg
        return DataWrapper(
            data=transferred.data,
            shape=expr.shape,
            axes=expr.axes,
            tags=expr.tags,
            non_equality_tags=expr.non_equality_tags)
class TransferToNumpyMapper(CopyMapper):
    """A mapper that turns the ``TaggableCLArray`` data held by
    :class:`~pytato.array.DataWrapper` instances into
    :class:`numpy.ndarray` instances, using
    :meth:`~arraycontext.ArrayContext.to_numpy`.
    """
    def __init__(self, actx: ArrayContext) -> None:
        super().__init__()
        self.actx = actx

    def map_data_wrapper(self, expr: DataWrapper) -> Array:
        """Return a copy of *expr* that wraps the transferred data.

        :raises ValueError: if the wrapped data is not a
            ``TaggableCLArray``.
        """
        import numpy as np

        import arraycontext.impl.pyopencl.taggable_cl_array as tga

        device_data = expr.data
        if not isinstance(device_data, tga.TaggableCLArray):
            raise ValueError("TransferToNumpyMapper: tried to transfer data that "
                             "is already on the host")

        host_data = self.actx.to_numpy(device_data)
        assert isinstance(host_data, np.ndarray)

        # https://github.com/pylint-dev/pylint/issues/3893
        # pylint: disable=unexpected-keyword-arg
        # type-ignore: discussed at
        # https://github.com/inducer/arraycontext/pull/289#discussion_r1855523967
        # possibly related: https://github.com/python/mypy/issues/17375
        return DataWrapper(  # type: ignore[call-arg]
            data=host_data,
            shape=expr.shape,
            axes=expr.axes,
            tags=expr.tags,
            non_equality_tags=expr.non_equality_tags)
def transfer_from_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
    """Return a copy of *expr* in which the arrays held by
    :class:`~pytato.array.DataWrapper` instances have been transferred to
    the device with :meth:`~arraycontext.ArrayContext.from_numpy`.
    """
    mapper = TransferFromNumpyMapper(actx)
    return mapper(expr)
def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
    """Return a copy of *expr* in which the arrays held by
    :class:`~pytato.array.DataWrapper` instances have been converted to
    :class:`numpy.ndarray` with :meth:`~arraycontext.ArrayContext.to_numpy`.
    """
    mapper = TransferToNumpyMapper(actx)
    return mapper(expr)
# }}}
# vim: foldmethod=marker
......@@ -2,6 +2,8 @@
.. currentmodule:: arraycontext
.. autofunction:: make_loopy_program
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
......@@ -27,8 +29,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from collections.abc import Mapping
from typing import ClassVar
import numpy as np
import loopy as lp
from loopy.version import MOST_RECENT_LANGUAGE_VERSION
from pytools import memoize_in
from arraycontext.container.traversal import multimapped_over_array_containers
from arraycontext.fake_numpy import BaseFakeNumpyNamespace
# {{{ loopy
......@@ -64,9 +75,96 @@ def get_default_entrypoint(t_unit):
except AttributeError:
try:
return t_unit.root_kernel
except AttributeError:
except AttributeError as err:
raise TypeError("unable to find default entry point for loopy "
"translation unit")
"translation unit") from err
def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes):
    """Return a loopy program applying the function *c_name* elementwise.

    The program reads *nargs* inputs named ``inp0``, ``inp1``, ... and
    writes one output named ``out``, each subscripted with *naxes* indices.
    The kernel is named ``actx_special_<c_name>`` and is tagged with
    ``ElementwiseMapKernelTag``.
    """
    # Program construction is wrapped in `memoize_in` on *actx*.
    # NOTE(review): presumably results are cached per (c_name, nargs, naxes)
    # on the array context -- confirm against pytools.memoize_in.
    @memoize_in(actx, _get_scalar_func_loopy_program)
    def get(c_name, nargs, naxes):
        from pymbolic.primitives import Subscript, Variable

        # One loop index i<k> and one size parameter n<k> per axis.
        var_names = [f"i{i}" for i in range(naxes)]
        size_names = [f"n{i}" for i in range(naxes)]
        subscript = tuple(Variable(vname) for vname in var_names)
        from islpy import make_zero_and_vars
        v = make_zero_and_vars(var_names, params=size_names)
        domain = v[0].domain()
        # Constrain every index to 0 <= i<k> < n<k>.
        for vname, sname in zip(var_names, size_names, strict=True):
            domain = domain & v[0].le_set(v[vname]) & v[vname].lt_set(v[sname])
        # The unpacking requires the domain to consist of one basic set.
        domain_bset, = domain.get_basic_sets()
        import loopy as lp
        from arraycontext.transform_metadata import ElementwiseMapKernelTag

        def sub(name: str) -> Variable | Subscript:
            # With naxes == 0 the subscript is empty: use the bare variable.
            return Subscript(Variable(name), subscript) if subscript else Variable(name)

        return make_loopy_program(
            [domain_bset], [
                lp.Assignment(
                    sub("out"),
                    Variable(c_name)(*[sub(f"inp{i}") for i in range(nargs)]))
                ], [
                    lp.GlobalArg("out", dtype=None, shape=lp.auto, offset=lp.auto)
                ] + [
                    lp.GlobalArg(f"inp{i}", dtype=None, shape=lp.auto, offset=lp.auto)
                    for i in range(nargs)
                # NOTE(review): the trailing Ellipsis presumably asks loopy to
                # infer the remaining arguments -- confirm.
                ] + [...],
            name=f"actx_special_{c_name}",
            tags=(ElementwiseMapKernelTag(),))

    return get(c_name, nargs, naxes)
class LoopyBasedFakeNumpyNamespace(BaseFakeNumpyNamespace):
    """A fake :mod:`numpy` namespace whose ``__getattr__`` hands elementwise
    math functions off to loopy programs built by
    :func:`_get_scalar_func_loopy_program`.
    """
    # numpy spelling -> C spelling of the inverse trigonometric/hyperbolic
    # functions
    _numpy_to_c_arc_functions: ClassVar[Mapping[str, str]] = {
        "arcsin": "asin",
        "arccos": "acos",
        "arctan": "atan",
        "arctan2": "atan2",
        "arcsinh": "asinh",
        "arccosh": "acosh",
        "arctanh": "atanh",
        }

    # the inverse mapping: C spelling -> numpy spelling
    _c_to_numpy_arc_functions: ClassVar[Mapping[str, str]] = {c_name: numpy_name
        for numpy_name, c_name in _numpy_to_c_arc_functions.items()}

    def __getattr__(self, name):
        """Return an implementation of the numpy function *name*.

        :raises RuntimeError: if *name* is the C spelling of an "arc"
            function (e.g. ``asin``).
        :raises AttributeError: if *name* is not handed off to loopy.
        """
        def loopy_implemented_elwise_func(*args):
            # Pure-scalar calls are evaluated directly with numpy.
            if all(np.isscalar(ary) for ary in args):
                return getattr(
                         np, self._c_to_numpy_arc_functions.get(name, name)
                         )(*args)
            actx = self._array_context
            # `c_name` is assigned further down in the enclosing function,
            # before this closure can be called. The number of axes is
            # taken from the first argument.
            prg = _get_scalar_func_loopy_program(actx,
                    c_name, nargs=len(args), naxes=len(args[0].shape))
            outputs = actx.call_loopy(prg,
                    **{f"inp{i}": arg for i, arg in enumerate(args)})
            return outputs["out"]

        if name in self._c_to_numpy_arc_functions:
            raise RuntimeError(f"'{name}' in ArrayContext.np has been removed. "
                f"Use '{self._c_to_numpy_arc_functions[name]}' as in numpy. ")

        # normalize to C names anyway
        c_name = self._numpy_to_c_arc_functions.get(name, name)

        # limit which functions we try to hand off to loopy
        # NOTE(review): the second condition cannot hold here, since C
        # spellings already raised RuntimeError above.
        if (name in self._numpy_math_functions
                or name in self._c_to_numpy_arc_functions):
            return multimapped_over_array_containers(loopy_implemented_elwise_func)
        else:
            raise AttributeError(
                f"'{type(self._array_context).__name__}.np' object "
                f"has no attribute '{name}'")
# }}}
......
"""
.. autoclass:: NameHint
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
"""
......@@ -22,36 +28,25 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import sys
from pytools.tag import Tag
from warnings import warn
from dataclasses import dataclass
from pytools.tag import UniqueTag
# {{{ deprecation handling
try:
from meshmode.transform_metadata import FirstAxisIsElementsTag \
as _FirstAxisIsElementsTag
except ImportError:
# placeholder in case meshmode is too old to have it.
class _FirstAxisIsElementsTag(Tag): # type: ignore[no-redef]
pass
@dataclass(frozen=True)
class NameHint(UniqueTag):
"""A tag acting on arrays or array axes. Express that :attr:`name` is a
useful starting point in forming an identifier for the tagged object.
.. attribute:: name
if sys.version_info >= (3, 7):
def __getattr__(name):
if name == "FirstAxisIsElementsTag":
warn(f"'arraycontext.{name}' is deprecated. "
f"Use 'meshmode.transform_metadata.{name}' instead. "
f"'arraycontext.{name}' will continue to work until 2022.",
DeprecationWarning, stacklevel=2)
return _FirstAxisIsElementsTag
else:
raise AttributeError(name)
else:
FirstAxisIsElementsTag = _FirstAxisIsElementsTag
A string. Must be a valid Python identifier. Not necessarily unique.
"""
name: str
# }}}
def __post_init__(self):
if not self.name.isidentifier():
raise ValueError("'name' must be an identifier")
# vim: foldmethod=marker
"""
.. currentmodule:: arraycontext
.. autoclass:: PytestArrayContextFactory
.. autoclass:: PytestPyOpenCLArrayContextFactory
.. autofunction:: pytest_generate_tests_for_array_contexts
.. autofunction:: pytest_generate_tests_for_pyopencl_array_context
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
......@@ -31,15 +33,25 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from typing import Any, Callable, Dict, Sequence, Type, Union
from collections.abc import Callable, Sequence
from typing import Any
import pyopencl as cl
from arraycontext import NumpyArrayContext
from arraycontext.context import ArrayContext
# {{{ array context factories
class PytestPyOpenCLArrayContextFactory:
class PytestArrayContextFactory:
    """Base class of the array context factories used by the pytest helpers.

    .. automethod:: is_available
    .. automethod:: __call__
    """

    @classmethod
    def is_available(cls) -> bool:
        """Return *True*; the base factory is always considered available."""
        available = True
        return available

    def __call__(self) -> ArrayContext:
        """Create an array context. Not implemented in the base class."""
        raise NotImplementedError
class PytestPyOpenCLArrayContextFactory(PytestArrayContextFactory):
"""
.. automethod:: __init__
.. automethod:: __call__
......@@ -51,6 +63,14 @@ class PytestPyOpenCLArrayContextFactory:
"""
self.device = device
@classmethod
def is_available(cls) -> bool:
try:
import pyopencl # noqa: F401
return True
except ImportError:
return False
def get_command_queue(self):
# Get rid of leftovers from past tests.
# CL implementations are surprisingly limited in how many
......@@ -61,17 +81,25 @@ class PytestPyOpenCLArrayContextFactory:
from gc import collect
collect()
import pyopencl as cl
# On Intel CPU CL, existence of a command queue does not ensure that
# the context survives.
ctx = cl.Context([self.device])
return ctx, cl.CommandQueue(ctx)
def __call__(self) -> ArrayContext:
raise NotImplementedError
class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFactory):
force_device_scalars = True
# Deprecated, remove in 2025.
_force_device_scalars = True
@property
def force_device_scalars(self):
from warnings import warn
warn(
"force_device_scalars is deprecated and will be removed in 2025.",
DeprecationWarning, stacklevel=2)
return self._force_device_scalars
@property
def actx_class(self):
......@@ -84,31 +112,45 @@ class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFact
# holding a reference to the context to keep it alive in turn.
# On some implementations (notably Intel CPU), holding a reference
# to a queue does not keep the context alive.
ctx, queue = self.get_command_queue()
_ctx, queue = self.get_command_queue()
alloc = None
if queue.device.platform.name == "NVIDIA CUDA":
from pyopencl.tools import ImmediateAllocator
alloc = ImmediateAllocator(queue)
from warnings import warn
warn("Disabling SVM due to memory leak "
"in Nvidia CL when running pytest. "
"See https://github.com/inducer/arraycontext/issues/196",
stacklevel=1)
return self.actx_class(
queue,
force_device_scalars=self.force_device_scalars)
allocator=alloc)
def __str__(self):
return ("<%s for <pyopencl.Device '%s' on '%s'>>" %
(
self.actx_class.__name__,
self.device.name.strip(),
self.device.platform.name.strip()))
return (f"<{self.actx_class.__name__} "
f"for <pyopencl.Device '{self.device.name.strip()}' "
f"on '{self.device.platform.name.strip()}'>>")
class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars(
_PytestPyOpenCLArrayContextFactoryWithClass):
force_device_scalars = False
class _PytestPytatoPyOpenCLArrayContextFactory(
PytestPyOpenCLArrayContextFactory):
class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory):
@classmethod
def is_available(cls) -> bool:
try:
import pyopencl # noqa: F401
import pytato # noqa: F401
return True
except ImportError:
return False
@property
def actx_class(self):
from arraycontext import PytatoPyOpenCLArrayContext
return PytatoPyOpenCLArrayContext
actx_cls = PytatoPyOpenCLArrayContext
return actx_cls
def __call__(self):
# The ostensibly pointless assignment to *ctx* keeps the CL context alive
......@@ -116,28 +158,107 @@ class _PytestPytatoPyOpenCLArrayContextFactory(
# holding a reference to the context to keep it alive in turn.
# On some implementations (notably Intel CPU), holding a reference
# to a queue does not keep the context alive.
ctx, queue = self.get_command_queue()
return self.actx_class(queue)
_ctx, queue = self.get_command_queue()
alloc = None
if queue.device.platform.name == "NVIDIA CUDA":
from pyopencl.tools import ImmediateAllocator
alloc = ImmediateAllocator(queue)
from warnings import warn
warn("Disabling SVM due to memory leak "
"in Nvidia CL when running pytest. "
"See https://github.com/inducer/arraycontext/issues/196",
stacklevel=1)
return self.actx_class(queue, allocator=alloc)
def __str__(self):
return ("<PytatoPyOpenCLArrayContext for <pyopencl.Device '%s' on '%s'>>" %
(
self.device.name.strip(),
self.device.platform.name.strip()))
return ("<PytatoPyOpenCLArrayContext for "
f"<pyopencl.Device '{self.device.name.strip()}' "
f"on '{self.device.platform.name.strip()}'>>")
_ARRAY_CONTEXT_FACTORY_REGISTRY: \
Dict[str, Type[PytestPyOpenCLArrayContextFactory]] = {
"pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass,
"pyopencl-deprecated":
_PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars,
"pytato-pyopencl": _PytestPytatoPyOpenCLArrayContextFactory,
}
class _PytestEagerJaxArrayContextFactory(PytestArrayContextFactory):
    """Pytest factory that creates ``EagerJAXArrayContext`` instances."""

    def __init__(self, *args, **kwargs):
        # All arguments are accepted and ignored.
        pass

    @classmethod
    def is_available(cls) -> bool:
        """Return *True* if :mod:`jax` can be imported."""
        try:
            import jax  # noqa: F401
        except ImportError:
            return False
        else:
            return True

    def __call__(self):
        """Enable 64-bit support in JAX and return a new array context."""
        from jax import config

        from arraycontext import EagerJAXArrayContext

        config.update("jax_enable_x64", True)
        actx = EagerJAXArrayContext()
        return actx

    def __str__(self):
        return "<EagerJAXArrayContext>"
class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory):
    """Pytest factory that creates ``PytatoJAXArrayContext`` instances."""

    def __init__(self, *args, **kwargs):
        # All arguments are accepted and ignored.
        pass

    @classmethod
    def is_available(cls) -> bool:
        """Return *True* if both :mod:`jax` and :mod:`pytato` can be imported."""
        try:
            import jax  # noqa: F401
            import pytato  # noqa: F401
        except ImportError:
            return False
        else:
            return True

    def __call__(self):
        """Enable 64-bit support in JAX and return a new array context."""
        from jax import config

        from arraycontext import PytatoJAXArrayContext

        config.update("jax_enable_x64", True)
        actx = PytatoJAXArrayContext()
        return actx

    def __str__(self):
        return "<PytatoJAXArrayContext>"
# {{{ _PytestArrayContextFactory
class _NumpyArrayContextForTests(NumpyArrayContext):
    """:class:`NumpyArrayContext` subclass handed out by the pytest factory."""

    def transform_loopy_program(self, t_unit):
        """Return *t_unit* as-is; no transformation is applied."""
        return t_unit
class _PytestNumpyArrayContextFactory(PytestArrayContextFactory):
    """Pytest factory that creates :class:`_NumpyArrayContextForTests`."""

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but not forwarded to the base class.
        super().__init__()

    def __call__(self):
        """Return a freshly constructed array context."""
        actx = _NumpyArrayContextForTests()
        return actx

    def __str__(self):
        return "<NumpyArrayContext>"
# }}}
# Maps the string names of the built-in array context factories to their
# factory classes. Names already present here are rejected by
# register_pytest_array_context_factory.
_ARRAY_CONTEXT_FACTORY_REGISTRY: dict[str, type[PytestArrayContextFactory]] = {
"pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass,
"pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory,
"pytato:jax": _PytestPytatoJaxArrayContextFactory,
"eagerjax": _PytestEagerJaxArrayContextFactory,
"numpy": _PytestNumpyArrayContextFactory,
}
def register_pytest_array_context_factory(
name: str,
factory: Type[PytestPyOpenCLArrayContextFactory]) -> None:
factory: type[PytestArrayContextFactory]) -> None:
if name in _ARRAY_CONTEXT_FACTORY_REGISTRY:
raise ValueError(f"factory '{name}' already exists")
......@@ -149,7 +270,7 @@ def register_pytest_array_context_factory(
# {{{ pytest integration
def pytest_generate_tests_for_array_contexts(
factories: Sequence[Union[str, Type[PytestPyOpenCLArrayContextFactory]]], *,
factories: Sequence[str | type[PytestArrayContextFactory]], *,
factory_arg_name: str = "actx_factory",
) -> Callable[[Any], None]:
"""Parametrize tests for pytest to use an :class:`~arraycontext.ArrayContext`.
......@@ -166,10 +287,7 @@ def pytest_generate_tests_for_array_contexts(
"pyopencl",
])
to use the :mod:`pyopencl`-based array context. For :mod:`pyopencl`-based
contexts :func:`pyopencl.tools.pytest_generate_tests_for_pyopencl` is used
as a backend, which allows specifying the ``PYOPENCL_TEST`` environment
variable for device selection.
to use the :mod:`pyopencl`-based array context.
The environment variable ``ARRAYCONTEXT_TEST`` can also be used to
overwrite any chosen implementations through *factories*. This is a
......@@ -177,11 +295,7 @@ def pytest_generate_tests_for_array_contexts(
Current supported implementations include:
* ``"pyopencl"``, which creates a :class:`~arraycontext.PyOpenCLArrayContext`
with ``force_device_scalars=True``.
* ``"pyopencl-deprecated"``, which creates a
:class:`~arraycontext.PyOpenCLArrayContext` with
``force_device_scalars=False``.
* ``"pyopencl"``, which creates a :class:`~arraycontext.PyOpenCLArrayContext`.
* ``"pytato:pyopencl"``, which creates a
:class:`~arraycontext.PytatoPyOpenCLArrayContext`.
......@@ -217,9 +331,19 @@ def pytest_generate_tests_for_array_contexts(
else:
raise ValueError(f"unknown array contexts: {unknown_factories}")
unique_factories = set([
_ARRAY_CONTEXT_FACTORY_REGISTRY.get(factory, factory) # type: ignore[misc]
for factory in unique_factories])
available_factories = {
factory for key in unique_factories
for factory in [_ARRAY_CONTEXT_FACTORY_REGISTRY.get(key, key)]
if (
not isinstance(factory, str)
and issubclass(factory, PytestArrayContextFactory)
and factory.is_available())
}
from pytools import partition
pyopencl_factories, other_factories = partition(
lambda factory: issubclass(factory, PytestPyOpenCLArrayContextFactory),
available_factories)
# }}}
......@@ -234,6 +358,7 @@ def pytest_generate_tests_for_array_contexts(
return
arg_values, ids = cl_tools.get_pyopencl_fixture_arg_values()
empty_arg_dict = dict.fromkeys(arg_values[0])
# }}}
......@@ -246,67 +371,34 @@ def pytest_generate_tests_for_array_contexts(
"'ctx_factory' / 'ctx_getter' as arguments.")
arg_values_with_actx = []
for arg_dict in arg_values:
if pyopencl_factories:
for arg_dict in arg_values:
arg_values_with_actx.extend([
{factory_arg_name: factory(arg_dict["device"]), **arg_dict}
for factory in pyopencl_factories
])
if other_factories:
arg_values_with_actx.extend([
{factory_arg_name: factory(arg_dict["device"]), **arg_dict}
for factory in unique_factories
{factory_arg_name: factory(), **empty_arg_dict}
for factory in other_factories
])
else:
arg_values_with_actx = arg_values
arg_value_tuples = [
tuple(arg_dict[name] for name in arg_names)
for arg_dict in arg_values_with_actx
]
# }}}
# Sort the actx's so that parallel pytest works
arg_value_tuples = sorted(arg_value_tuples, key=lambda x: x.__str__())
# NOTE: sorts the args so that parallel pytest works
arg_value_tuples = sorted([
tuple(arg_dict[name] for name in arg_names)
for arg_dict in arg_values_with_actx
], key=lambda x: str(x))
metafunc.parametrize(arg_names, arg_value_tuples, ids=ids)
return inner
def pytest_generate_tests_for_pyopencl_array_context(metafunc) -> None:
"""Parametrize tests for pytest to use a
:class:`~arraycontext.PyOpenCLArrayContext`.
Performs device enumeration analogously to
:func:`pyopencl.tools.pytest_generate_tests_for_pyopencl`.
Using the line:
.. code-block:: python
from arraycontext import (
pytest_generate_tests_for_pyopencl_array_context
as pytest_generate_tests)
in your pytest test scripts allows you to use the argument ``actx_factory``,
in your test functions, and they will automatically be
run once for each OpenCL device/platform in the system, as appropriate,
with an argument-less function that returns an
:class:`~arraycontext.ArrayContext` when called.
It also allows you to specify the ``PYOPENCL_TEST`` environment variable
for device selection.
"""
from warnings import warn
warn("pytest_generate_tests_for_pyopencl_array_context is deprecated. "
"Use 'pytest_generate_tests = "
"arraycontext.pytest_generate_tests_for_array_contexts"
"([\"pyopencl-deprecated\"])' instead. "
"pytest_generate_tests_for_pyopencl_array_context will stop working "
"in 2022.",
DeprecationWarning, stacklevel=2)
pytest_generate_tests_for_array_contexts([
"pyopencl-deprecated",
], factory_arg_name="actx_factory")(metafunc)
# }}}
......
......@@ -4,6 +4,8 @@
.. autoclass:: CommonSubexpressionTag
.. autoclass:: ElementwiseMapKernelTag
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
......
VERSION = (2021, 1)
VERSION_TEXT = ".".join(str(i) for i in VERSION)
from __future__ import annotations
from importlib import metadata
def _parse_version(version: str) -> tuple[tuple[int, ...], str]:
import re
m = re.match(r"^([0-9.]+)([a-z0-9]*?)$", VERSION_TEXT)
assert m is not None
return tuple(int(nr) for nr in m.group(1).split(".")), m.group(2)
VERSION_TEXT = metadata.version("arraycontext")
VERSION, VERSION_STATUS = _parse_version(VERSION_TEXT)
......@@ -4,17 +4,3 @@ The Array Context Abstraction
.. automodule:: arraycontext
.. automodule:: arraycontext.context
Implementations of the Array Context Abstraction
================================================
Array context based on :mod:`pyopencl.array`
--------------------------------------------
.. automodule:: arraycontext.impl.pyopencl
Lazy/Deferred evaluation array context based on :mod:`pytato`
-------------------------------------------------------------
.. automodule:: arraycontext.impl.pytato
# -- Path setup --------------------------------------------------------------
from importlib import metadata
from urllib.request import urlopen
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
_conf_url = \
"https://raw.githubusercontent.com/inducer/sphinxconfig/main/sphinxconfig.py"
with urlopen(_conf_url) as _inf:
exec(compile(_inf.read(), _conf_url, "exec"), globals())
# -- Project information -----------------------------------------------------
project = "arraycontext"
copyright = "2021, University of Illinois Board of Trustees"
author = "Arraycontext Contributors"
release = metadata.version("arraycontext")
version = ".".join(release.split(".")[:2])
ver_dic = {}
exec(
compile(
open("../arraycontext/version.py").read(),
"../arraycontext/version.py", "exec"),
ver_dic)
version = ".".join(str(x) for x in ver_dic["VERSION"])
# The full version, including alpha/beta/rc tags.
release = ver_dic["VERSION_TEXT"]
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.doctest",
"sphinx.ext.intersphinx",
"sphinx.ext.mathjax",
"sphinx.ext.graphviz",
"sphinx_copybutton",
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
intersphinx_mapping = {
"jax": ("https://jax.readthedocs.io/en/latest/", None),
"loopy": ("https://documen.tician.de/loopy", None),
"meshmode": ("https://documen.tician.de/meshmode", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"pymbolic": ("https://documen.tician.de/pymbolic", None),
"pyopencl": ("https://documen.tician.de/pyopencl", None),
"pytato": ("https://documen.tician.de/pytato", None),
"pytest": ("https://docs.pytest.org/en/latest/", None),
"python": ("https://docs.python.org/3/", None),
"pytools": ("https://documen.tician.de/pytools", None),
}
# Some modules need to import things just so that sphinx can resolve symbols in
# type annotations. Often, we do not want these imports (e.g. of PyOpenCL) when
# in normal use (because they would introduce unintended side effects or hard
# dependencies). This flag exists so that these imports only occur during doc
# build. Since sphinx appears to resolve type hints lexically (as it should),
# this needs to be cross-module (since, e.g. an inherited arraycontext
# docstring can be read by sphinx when building meshmode, a dependent package),
# this needs a setting of the same name across all packages involved, that's
# why this name is as global-sounding as it is.
import sys
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "furo"
sys._BUILDING_SPHINX_DOCS = True
intersphinx_mapping = {
"https://docs.python.org/3/": None,
"https://numpy.org/doc/stable/": None,
"https://documen.tician.de/pytools": None,
"https://documen.tician.de/pyopencl": None,
"https://documen.tician.de/pytato": None,
"https://documen.tician.de/loopy": None,
"https://documen.tician.de/meshmode": None,
"https://docs.pytest.org/en/latest/": None,
}
autoclass_content = "class"
nitpick_ignore_regex = [
["py:class", r"arraycontext\.context\.ContainerOrScalarT"],
]
Implementations of the Array Context Abstraction
================================================
..
When adding a new array context here, make sure to also add it to and run
```
doc/make_numpy_coverage_table.py
```
to update the coverage table below!
Array context based on :mod:`numpy`
--------------------------------------------
.. automodule:: arraycontext.impl.numpy
Array context based on :mod:`pyopencl.array`
--------------------------------------------
.. automodule:: arraycontext.impl.pyopencl
Lazy/Deferred evaluation array context based on :mod:`pytato`
-------------------------------------------------------------
.. automodule:: arraycontext.impl.pytato
Array context based on :mod:`jax.numpy`
---------------------------------------
.. automodule:: arraycontext.impl.jax
.. _numpy-coverage:
:mod:`numpy` coverage
---------------------
This is a list of functionality implemented by :attr:`arraycontext.ArrayContext.np`.
.. note::
Only functions and methods that have at least one implementation are listed.
.. include:: numpy_coverage.rst
......@@ -7,6 +7,7 @@ implementations for:
- :mod:`numpy`
- :mod:`pyopencl`
- :mod:`jax.numpy`
- :mod:`pytato` (for lazy/deferred evaluation)
- Debugging
- Profiling
......@@ -14,11 +15,45 @@ implementations for:
:mod:`arraycontext` started life as an array abstraction for use with the
:mod:`meshmode` unstructured discretization package.
Design Guidelines
-----------------
Here are some of the guidelines we aim to follow in :mod:`arraycontext`. There
exist numerous other, related efforts, such as the `Python array API standard
<https://data-apis.org/array-api/latest/purpose_and_scope.html>`__. These
points may aid in clarifying and differentiating our objectives.
- The array context is about exposing the common subset of operations
available in immutable and mutable arrays. As a result, the interface
does *not* seek to support interfaces that provide, enable, or are typically
used only with in-place mutation.
For example: The equivalents of :func:`numpy.empty` were deprecated
and will eventually be removed.
- Each array context offers a specific subset of :mod:`numpy` under
:attr:`arraycontext.ArrayContext.np`. Functions under this namespace
must be unconditionally :mod:`numpy`-compatible, that is, they may not
offer an interface beyond what numpy offers. Functions that are
incompatible, for example by supporting tag metadata
(cf. :meth:`arraycontext.ArrayContext.einsum`) should live under the
:class:`~arraycontext.ArrayContext` directly.
- Similarly, we strive to minimize redundancy between attributes of
:class:`~arraycontext.ArrayContext` and :attr:`arraycontext.ArrayContext.np`.
For example: ``ArrayContext.empty_like`` was deprecated.
- Array containers are data structures that may contain arrays.
See :mod:`arraycontext.container`. We strive to support these, where sensible,
in :class:`~arraycontext.ArrayContext` and :attr:`arraycontext.ArrayContext.np`.
Contents
--------
.. toctree::
array_context
implementations
container
other
misc
......
"""
Workflow:
1. If a new array context is implemented, it should be added to
:func:`initialize_contexts`.
2. If a new function is implemented, it should be added to the
corresponding ``write_section_name`` function.
3. Once everything is added, regenerate the tables using
.. code::
python make_numpy_coverage_table.py numpy_coverage.rst
"""
from __future__ import annotations
import pathlib
from mako.template import Template
import arraycontext
# {{{ templating
HEADER = """
.. raw:: html
<style> .red {color:red} </style>
<style> .green {color:green} </style>
.. role:: red
.. role:: green
"""
TABLE_TEMPLATE = Template("""
${title}
${'~' * len(title)}
.. list-table::
:header-rows: 1
* - Function
% for ctx in contexts:
- :class:`~arraycontext.${type(ctx).__name__}`
% endfor
% for name, (directive, in_context) in numpy_functions_for_context.items():
* - :${directive}:`numpy.${name}`
% for ctx in contexts:
<%
flag = in_context.get(type(ctx), "yes").capitalize()
color = "green" if flag == "Yes" else "red"
%> - :${color}:`${flag}`
% endfor
% endfor
""")
def initialize_contexts():
    """Create one instance of each array context shown in the coverage table.

    :returns: a list of array contexts, in the order in which they appear
        as columns of the generated tables.
    """
    import pyopencl as cl

    cl_ctx = cl.create_some_context()
    cq = cl.CommandQueue(cl_ctx)

    # Both PyOpenCL-based contexts share the same command queue.
    actxs = []
    actxs.append(
        arraycontext.PyOpenCLArrayContext(cq, force_device_scalars=True))
    actxs.append(arraycontext.EagerJAXArrayContext())
    actxs.append(arraycontext.PytatoPyOpenCLArrayContext(cq))
    actxs.append(arraycontext.PytatoJAXArrayContext())
    return actxs
def build_supported_functions(funcs, contexts):
    """Determine which of *funcs* are missing from each context's ``np``.

    :arg funcs: an iterable of ``(sphinx_directive, name)`` pairs, where
        *name* must be an attribute of :mod:`numpy`.
    :arg contexts: an iterable of objects with an ``np`` attribute.
    :returns: a :class:`dict` mapping each name to a tuple
        ``(sphinx_directive, in_context)``. *in_context* maps the type of
        every context lacking the function to ``"No"``; contexts that
        provide it have no entry.
    :raises ValueError: if a name is not found in the :mod:`numpy` namespace.
    """
    import numpy as np

    def _is_missing(actx, attr):
        # An AttributeError from either ``actx.np`` or the lookup of *attr*
        # counts as "not supported".
        try:
            getattr(actx.np, attr)
        except AttributeError:
            return True
        return False

    result = {}
    for directive, name in funcs:
        if not hasattr(np, name):
            raise ValueError(f"'{name}' not found in numpy namespace")

        unsupported = {
            type(actx): "No"
            for actx in contexts
            if _is_missing(actx, name)
        }
        result[name] = (directive, unsupported)

    return result
# }}}
# {{{ writing
def write_array_creation_routines(outf, contexts):
    """Write the "Array creation routines" coverage table to *outf*."""
    # https://numpy.org/doc/stable/reference/routines.array-creation.html
    names = ("empty_like", "ones_like", "zeros_like", "full_like", "copy")

    # (sphinx-directive, name)
    funcs = tuple(("func", name) for name in names)

    coverage = build_supported_functions(funcs, contexts)
    table = TABLE_TEMPLATE.render(
        title="Array creation routines",
        contexts=contexts,
        numpy_functions_for_context=coverage,
        )
    outf.write(table)
def write_array_manipulation_routines(outf, contexts):
    """Write the "Array manipulation routines" coverage table to *outf*."""
    # https://numpy.org/doc/stable/reference/routines.array-manipulation.html
    names = (
        "reshape", "ravel", "transpose",
        "broadcast_to", "concatenate", "stack",
        )

    # (sphinx-directive, name)
    funcs = tuple(("func", name) for name in names)

    coverage = build_supported_functions(funcs, contexts)
    table = TABLE_TEMPLATE.render(
        title="Array manipulation routines",
        contexts=contexts,
        numpy_functions_for_context=coverage,
        )
    outf.write(table)
def write_linear_algebra(outf, contexts):
    """Write the "Linear algebra" coverage table to *outf*."""
    # https://numpy.org/doc/stable/reference/routines.linalg.html
    # (sphinx-directive, name)
    funcs = (("func", "vdot"),)

    coverage = build_supported_functions(funcs, contexts)
    table = TABLE_TEMPLATE.render(
        title="Linear algebra",
        contexts=contexts,
        numpy_functions_for_context=coverage,
        )
    outf.write(table)
def write_logic_functions(outf, contexts):
    """Write the "Logic Functions" coverage table to *outf*."""
    # https://numpy.org/doc/stable/reference/routines.logic.html
    reductions = ("all", "any")
    comparisons = (
        "greater", "greater_equal",
        "less", "less_equal",
        "equal", "not_equal",
        )

    # (sphinx-directive, name)
    funcs = (
        tuple(("func", name) for name in reductions)
        + tuple(("data", name) for name in comparisons))

    coverage = build_supported_functions(funcs, contexts)
    table = TABLE_TEMPLATE.render(
        title="Logic Functions",
        contexts=contexts,
        numpy_functions_for_context=coverage,
        )
    outf.write(table)
def write_mathematical_functions(outf, contexts):
    """Write the "Mathematical functions" coverage table to *outf*."""
    # https://numpy.org/doc/stable/reference/routines.math.html
    # (sphinx-directive, name), listed in the order of the table rows
    funcs = (
        ("data", "sin"), ("data", "cos"), ("data", "tan"),
        ("data", "arcsin"), ("data", "arccos"), ("data", "arctan"),
        ("data", "arctan2"),
        ("data", "sinh"), ("data", "cosh"), ("data", "tanh"),
        ("data", "floor"), ("data", "ceil"),
        ("func", "sum"),
        ("data", "exp"), ("data", "log"), ("data", "log10"),
        ("func", "real"), ("func", "imag"), ("data", "conjugate"),
        ("data", "maximum"), ("func", "amax"),
        ("data", "minimum"), ("func", "amin"),
        ("data", "sqrt"), ("data", "absolute"), ("data", "fabs"),
        )

    coverage = build_supported_functions(funcs, contexts)
    table = TABLE_TEMPLATE.render(
        title="Mathematical functions",
        contexts=contexts,
        numpy_functions_for_context=coverage,
        )
    outf.write(table)
def write_searching_sorting_and_counting(outf, contexts):
    """Write the "Sorting, searching, and counting" coverage table to *outf*."""
    # https://numpy.org/doc/stable/reference/routines.sort.html
    # (sphinx-directive, name)
    funcs = (("func", "where"),)

    coverage = build_supported_functions(funcs, contexts)
    table = TABLE_TEMPLATE.render(
        title="Sorting, searching, and counting",
        contexts=contexts,
        numpy_functions_for_context=coverage,
        )
    outf.write(table)
# }}}
if __name__ == "__main__":
    import argparse
    import sys

    # The optional positional argument is the output file; without it, the
    # tables are written to stdout.
    parser = argparse.ArgumentParser()
    parser.add_argument("filename", nargs="?", type=pathlib.Path, default=None)
    args = parser.parse_args()

    ctxs = initialize_contexts()

    def write(outf):
        """Write the header followed by every coverage section to *outf*."""
        outf.write(HEADER)

        # NOTE: every ``write_*`` section defined in this file must be listed
        # here, otherwise it is silently missing from the generated table.
        # (``write_searching_sorting_and_counting`` was previously left out.)
        for write_section in (
                write_array_creation_routines,
                write_array_manipulation_routines,
                write_linear_algebra,
                write_logic_functions,
                write_mathematical_functions,
                write_searching_sorting_and_counting,
                ):
            write_section(outf, ctxs)

    if args.filename is not None:
        # An explicit encoding keeps the output independent of the locale.
        with open(args.filename, "w", encoding="utf-8") as outf:
            write(outf)
    else:
        write(sys.stdout)
.. raw:: html
<style> .red {color:red} </style>
<style> .green {color:green} </style>
.. role:: red
.. role:: green
Array creation routines
~~~~~~~~~~~~~~~~~~~~~~~
.. list-table::
:header-rows: 1
* - Function
- :class:`~arraycontext.PyOpenCLArrayContext`
- :class:`~arraycontext.EagerJAXArrayContext`
- :class:`~arraycontext.PytatoPyOpenCLArrayContext`
- :class:`~arraycontext.PytatoJAXArrayContext`
* - :func:`numpy.empty_like`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.ones_like`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.zeros_like`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.full_like`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.copy`
- :green:`Yes`
- :green:`Yes`
- :red:`No`
- :red:`No`
Array manipulation routines
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. list-table::
:header-rows: 1
* - Function
- :class:`~arraycontext.PyOpenCLArrayContext`
- :class:`~arraycontext.EagerJAXArrayContext`
- :class:`~arraycontext.PytatoPyOpenCLArrayContext`
- :class:`~arraycontext.PytatoJAXArrayContext`
* - :func:`numpy.reshape`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.ravel`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.transpose`
- :red:`No`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.broadcast_to`
- :red:`No`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.concatenate`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.stack`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
Linear algebra
~~~~~~~~~~~~~~
.. list-table::
:header-rows: 1
* - Function
- :class:`~arraycontext.PyOpenCLArrayContext`
- :class:`~arraycontext.EagerJAXArrayContext`
- :class:`~arraycontext.PytatoPyOpenCLArrayContext`
- :class:`~arraycontext.PytatoJAXArrayContext`
* - :func:`numpy.vdot`
- :green:`Yes`
- :green:`Yes`
- :red:`No`
- :red:`No`
Logic Functions
~~~~~~~~~~~~~~~
.. list-table::
:header-rows: 1
* - Function
- :class:`~arraycontext.PyOpenCLArrayContext`
- :class:`~arraycontext.EagerJAXArrayContext`
- :class:`~arraycontext.PytatoPyOpenCLArrayContext`
- :class:`~arraycontext.PytatoJAXArrayContext`
* - :func:`numpy.all`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.any`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.greater`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.greater_equal`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.less`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.less_equal`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.equal`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.not_equal`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
Mathematical functions
~~~~~~~~~~~~~~~~~~~~~~
.. list-table::
:header-rows: 1
* - Function
- :class:`~arraycontext.PyOpenCLArrayContext`
- :class:`~arraycontext.EagerJAXArrayContext`
- :class:`~arraycontext.PytatoPyOpenCLArrayContext`
- :class:`~arraycontext.PytatoJAXArrayContext`
* - :data:`numpy.sin`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.cos`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.tan`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.arcsin`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.arccos`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.arctan`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.arctan2`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.sinh`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.cosh`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.tanh`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.floor`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.ceil`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.sum`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.exp`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.log`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.log10`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.real`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.imag`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.conjugate`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.maximum`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.amax`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.minimum`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :func:`numpy.amin`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.sqrt`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.absolute`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
* - :data:`numpy.fabs`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
- :green:`Yes`
Other functionality
===================
.. _metadata:
Metadata ("tags") for Arrays and Array Axes
-------------------------------------------
.. automodule:: arraycontext.metadata
:class:`~arraycontext.ArrayContext`-generating fixture for :mod:`pytest`
------------------------------------------------------------------------
......
#! /bin/sh
rsync --verbose --archive --delete _build/html/* doc-upload:doc/arraycontext
rsync --verbose --archive --delete _build/html/ doc-upload:doc/arraycontext
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "arraycontext"
version = "2024.0"
description = "Choose your favorite numpy-workalike"
readme = "README.rst"
license = { text = "MIT" }
authors = [
{ name = "Andreas Kloeckner", email = "inform@tiker.net" },
]
requires-python = ">=3.10"
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Other Audience",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python",
"Programming Language :: Python :: 3 :: Only",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Information Analysis",
"Topic :: Scientific/Engineering :: Mathematics",
"Topic :: Software Development :: Libraries",
"Topic :: Utilities",
]
dependencies = [
"immutabledict>=4.1",
"numpy",
"pytools>=2024.1.3",
# for Self
"typing_extensions>=4",
]
[project.optional-dependencies]
jax = [
"jax>=0.4",
]
pyopencl = [
"islpy>=2024.1",
"loopy>=2024.1",
"pyopencl>=2024.1",
]
pytato = [
"pytato>=2021.1",
]
test = [
"mypy",
"pytest",
"ruff",
]
[project.urls]
Documentation = "https://documen.tician.de/arraycontext"
Homepage = "https://github.com/inducer/arraycontext"
[tool.ruff]
preview = true
[tool.ruff.lint]
extend-select = [
"B", # flake8-bugbear
"C", # flake8-comprehensions
"E", # pycodestyle
"F", # pyflakes
"G", # flake8-logging-format
"I", # flake8-isort
"N", # pep8-naming
"NPY", # numpy
"Q", # flake8-quotes
"RUF", # ruff
"UP", # pyupgrade
"W", # pycodestyle
"SIM",
]
extend-ignore = [
"C90", # McCabe complexity
"E221", # multiple spaces before operator
"E226", # missing whitespace around arithmetic operator
"E402", # module-level import not at top of file
]
[tool.ruff.lint.flake8-quotes]
docstring-quotes = "double"
inline-quotes = "double"
multiline-quotes = "double"
[tool.ruff.lint.isort]
combine-as-imports = true
known-first-party = [
"jax",
"loopy",
"pymbolic",
"pyopencl",
"pytato",
"pytools",
]
known-local-folder = [
"arraycontext",
"testlib",
]
lines-after-imports = 2
required-imports = ["from __future__ import annotations"]
[tool.ruff.lint.per-file-ignores]
"doc/conf.py" = ["I002"]
# To avoid a requirement of array container definitions being someplace importable
# from @dataclass_array_container.
"test/test_utils.py" = ["I002"]
[tool.mypy]
python_version = "3.10"
warn_unused_ignores = true
# TODO: enable this
# check_untyped_defs = true
[[tool.mypy.overrides]]
module = [
"islpy.*",
"loopy.*",
"meshmode.*",
"pymbolic",
"pymbolic.*",
"pyopencl.*",
"jax.*",
]
ignore_missing_imports = true
[tool.typos.default]
extend-ignore-re = [
"(?Rm)^.*(#|//)\\s*spellchecker:\\s*disable-line$"
]