diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index e8f34d45be656c8dfa96b1cc83ab52655597556c..d9861d61591c155ae33bb0d6d4b25b8eb5da4692 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -64,9 +64,13 @@ from .container.traversal import ( outer) from .impl.pyopencl import PyOpenCLArrayContext -from .impl.pytato import PytatoPyOpenCLArrayContext +from .impl.pytato import (PytatoPyOpenCLArrayContext, + PytatoJAXArrayContext, + _BasePytatoArrayContext) +from .impl.jax import EagerJAXArrayContext from .pytest import ( + PytestArrayContextFactory, PytestPyOpenCLArrayContextFactory, pytest_generate_tests_for_array_contexts, pytest_generate_tests_for_pyopencl_array_context) @@ -102,9 +106,12 @@ __all__ = ( "outer", "PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext", + "PytatoJAXArrayContext", "_BasePytatoArrayContext", + "EagerJAXArrayContext", "make_loopy_program", + "PytestArrayContextFactory", "PytestPyOpenCLArrayContextFactory", "pytest_generate_tests_for_array_contexts", "pytest_generate_tests_for_pyopencl_array_context" diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..72c9318c9be7f5b7d298a1fb4e47015fc87791a5 --- /dev/null +++ b/arraycontext/impl/jax/__init__.py @@ -0,0 +1,129 @@ +""" +.. currentmodule:: arraycontext +.. autoclass:: EagerJAXArrayContext +""" + +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import numpy as np + +from typing import Union, Callable, Any +from pytools.tag import ToTagSetConvertible +from arraycontext.context import ArrayContext, _ScalarLike + + +class EagerJAXArrayContext(ArrayContext): + """ + A :class:`ArrayContext` that uses + :class:`jaxlib.xla_extension.DeviceArrayBase` instances for its base array + class and performs all array operations eagerly. See + :class:`~arraycontext.PytatoJAXArrayContext` for a lazier version. + + .. note:: + + JAX stores a global configuration state in :data:`jax.config`. Callers + are expected to maintain those. Most important for scientific computing + workloads being ``jax_enable_x64``. + """ + + def __init__(self) -> None: + super().__init__() + + from jax.numpy import DeviceArray + self.array_types = (DeviceArray, ) + + def _get_fake_numpy_namespace(self): + from .fake_numpy import EagerJAXFakeNumpyNamespace + return EagerJAXFakeNumpyNamespace(self) + + # {{{ ArrayContext interface + + def empty(self, shape, dtype): + import jax.numpy as jnp + return jnp.empty(shape=shape, dtype=dtype) + + def zeros(self, shape, dtype): + import jax.numpy as jnp + return jnp.zeros(shape=shape, dtype=dtype) + + def from_numpy(self, array: Union[np.ndarray, _ScalarLike]): + import jax + return jax.device_put(array) + + def to_numpy(self, array): + import jax + # jax.device_get can take scalars as well. + return jax.device_get(array) + + def call_loopy(self, t_unit, **kwargs): + raise NotImplementedError("calling loopy on JAX arrays" + " not supported. Maybe rewrite" + " the loopy kernel as numpy-flavored array" + " operations using ArrayContext.np.") + + def freeze(self, array): + return array.block_until_ready() + + def thaw(self, array): + return array + + # }}} + + def tag(self, tags: ToTagSetConvertible, array): + # Sorry, not capable. + return array + + def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): + # TODO: See `jax.experiemental.maps.xmap`, proabably that should be useful? + return array + + def clone(self): + return type(self)() + + def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + return f + + def einsum(self, spec, *args, arg_names=None, tagged=()): + import jax.numpy as jnp + if arg_names is not None: + from warnings import warn + warn("'arg_names' don't bear any significance in " + "EagerJAXArrayContext.", stacklevel=2) + + return jnp.einsum(spec, *args) + + @property + def permits_inplace_modification(self): + return False + + @property + def supports_nonscalar_broadcasting(self): + return True + + @property + def permits_advanced_indexing(self): + return True + +# vim: foldmethod=marker diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..d0466eeeaf5de6b215708d6203b45ba60c60f3e3 --- /dev/null +++ b/arraycontext/impl/jax/fake_numpy.py @@ -0,0 +1,143 @@ +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from functools import partial, reduce + +from arraycontext.fake_numpy import ( + BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace, + ) +from arraycontext.container.traversal import ( + rec_multimap_array_container, rec_map_array_container, + rec_map_reduce_array_container, + ) +from arraycontext.container import NotAnArrayContainerError, serialize_container +import numpy +import jax.numpy as jnp + + +class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): + # Everything is implemented in the base class for now. + pass + + +class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace): + """ + A :mod:`numpy` mimic for :class:`~arraycontext.EagerJAXArrayContext`. + """ + def _get_fake_numpy_linalg_namespace(self): + return EagerJAXFakeNumpyLinalgNamespace(self._array_context) + + def __getattr__(self, name): + return partial(rec_multimap_array_container, getattr(jnp, name)) + + def reshape(self, a, newshape, order="C"): + return rec_map_array_container( + lambda ary: jnp.reshape(ary, newshape, order=order), + a) + + def transpose(self, a, axes=None): + return rec_multimap_array_container(jnp.transpose, a, axes) + + def concatenate(self, arrays, axis=0): + return rec_multimap_array_container(jnp.concatenate, arrays, axis) + + def where(self, criterion, then, else_): + return rec_multimap_array_container(jnp.where, criterion, then, else_) + + def sum(self, a, axis=None, dtype=None): + return rec_map_reduce_array_container(sum, + partial(jnp.sum, + axis=axis, + dtype=dtype), + a) + + def min(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a) + + def max(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a) + + def stack(self, arrays, axis=0): + return rec_multimap_array_container( + lambda *args: jnp.stack(arrays=args, axis=axis), + *arrays) + + def array_equal(self, a, b): + actx = self._array_context + + # NOTE: not all backends support `bool` properly, so use `int8` instead + false = actx.from_numpy(numpy.int8(False)) + + def rec_equal(x, y): + if type(x) != type(y): + return false + + try: + iterable = zip(serialize_container(x), serialize_container(y)) + except NotAnArrayContainerError: + if x.shape != y.shape: + return false + else: + return jnp.all(jnp.equal(x, y)) + else: + return reduce( + jnp.logical_and, + [rec_equal(ix, iy) for (_, ix), (_, iy) in iterable] + ) + + return rec_equal(a, b) + + def ravel(self, a, order="C"): + """ + .. warning:: + + Since :func:`jax.numpy.reshape` does not support orders `A`` and + ``K``, in such cases we fallback to using ``order = C``. + """ + if order in "AK": + from warnings import warn + warn(f"ravel with order='{order}' not supported by JAX," + " using order=C.") + order = "C" + + return rec_map_array_container(lambda subary: jnp.ravel(subary, order=order), + a) + + def vdot(self, x, y, dtype=None): + from arraycontext import rec_multimap_reduce_array_container + + def _rec_vdot(ary1, ary2): + if dtype not in [None, numpy.find_common_type((ary1.dtype, + ary2.dtype), + ())]: + raise NotImplementedError(f"{type(self)} cannot take dtype in" + " vdot.") + + return jnp.vdot(ary1, ary2) + + return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y) + + def broadcast_to(self, array, shape): + return rec_map_array_container(partial(jnp.broadcast_to, shape=shape), array) diff --git a/doc/array_context.rst b/doc/array_context.rst index 680d49bf793ade7aba8622e50f7f4f2734e3d289..c137c0899afa218e195a21369d3856f47b461f7e 100644 --- a/doc/array_context.rst +++ b/doc/array_context.rst @@ -19,6 +19,12 @@ Lazy/Deferred evaluation array context based on :mod:`pytato` .. automodule:: arraycontext.impl.pytato + +Array context :mod:`jax.numpy` +------------------------------------------------------------- + +.. automodule:: arraycontext.impl.jax + .. _numpy-coverage: :mod:`numpy` coverage diff --git a/doc/conf.py b/doc/conf.py index bee0e10b98630bf89addaabe7931fd5eeaf86478..c91572488c400ff40959254bf745ea1e7052639d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -24,6 +24,7 @@ intersphinx_mapping = { "https://documen.tician.de/loopy": None, "https://documen.tician.de/meshmode": None, "https://docs.pytest.org/en/latest/": None, + "https://jax.readthedocs.io/en/latest/": None, } import sys