From 5bf46bc61c3325981852244130544561c26be139 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Sun, 26 Sep 2021 02:38:28 -0500 Subject: [PATCH] Implements NumpyArrayContext --- arraycontext/__init__.py | 7 +- arraycontext/container/arithmetic.py | 1 - arraycontext/impl/numpy/__init__.py | 126 +++++++++++++++++++++++ arraycontext/impl/numpy/fake_numpy.py | 143 ++++++++++++++++++++++++++ doc/implementations.rst | 5 + 5 files changed, 278 insertions(+), 4 deletions(-) create mode 100644 arraycontext/impl/numpy/__init__.py create mode 100644 arraycontext/impl/numpy/fake_numpy.py diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 1d0efb3..2f2640d 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -78,6 +78,7 @@ from .context import ( tag_axes, ) from .impl.jax import EagerJAXArrayContext +from .impl.numpy import NumpyArrayContext from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext from .loopy import make_loopy_program @@ -91,7 +92,6 @@ from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag __all__ = ( - "Array", "Array", "ArrayContainer", "ArrayContainerT", @@ -105,13 +105,13 @@ __all__ = ( "EagerJAXArrayContext", "ElementwiseMapKernelTag", "NotAnArrayContainerError", + "NumpyArrayContext", "PyOpenCLArrayContext", "PytatoJAXArrayContext", "PytatoPyOpenCLArrayContext", "PytestArrayContextFactory", "PytestPyOpenCLArrayContextFactory", "Scalar", - "Scalar", "ScalarLike", "dataclass_array_container", "deserialize_container", @@ -146,8 +146,9 @@ __all__ = ( "to_numpy", "unflatten", "with_array_context", + "with_container_arithmetic", "with_container_arithmetic" -) + ) # {{{ deprecation handling diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 4c8a09a..63f9327 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -33,7 +33,6 @@ THE SOFTWARE. """ from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union -from warnings import warn import numpy as np diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py new file mode 100644 index 0000000..dbc725f --- /dev/null +++ b/arraycontext/impl/numpy/__init__.py @@ -0,0 +1,126 @@ +""" +.. currentmodule:: arraycontext + +A mod :`numpy`-based array context. + +.. autoclass:: NumpyArrayContext +""" + +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from typing import Dict, Sequence, Union + +import numpy as np + +import loopy as lp +from pytools.tag import Tag + +from arraycontext.context import ArrayContext + + +class NumpyArrayContext(ArrayContext): + """ + A :class:`ArrayContext` that uses :class:`numpy.ndarray` to represent arrays. + + .. automethod:: __init__ + """ + def __init__(self): + super().__init__() + self._loopy_transform_cache: \ + Dict[lp.TranslationUnit, lp.TranslationUnit] = {} + + self.array_types = (np.ndarray,) + + def _get_fake_numpy_namespace(self): + from .fake_numpy import NumpyFakeNumpyNamespace + return NumpyFakeNumpyNamespace(self) + + # {{{ ArrayContext interface + + def clone(self): + return type(self)() + + def empty(self, shape, dtype): + return np.empty(shape, dtype=dtype) + + def zeros(self, shape, dtype): + return np.zeros(shape, dtype) + + def from_numpy(self, np_array: np.ndarray): + # Uh oh... + return np_array + + def to_numpy(self, array): + # Uh oh... + return array + + def call_loopy(self, t_unit, **kwargs): + t_unit = t_unit.copy(target=lp.ExecutableCTarget()) + try: + t_unit = self._loopy_transform_cache[t_unit] + except KeyError: + orig_t_unit = t_unit + t_unit = self.transform_loopy_program(t_unit) + self._loopy_transform_cache[orig_t_unit] = t_unit + del orig_t_unit + + _, result = t_unit(**kwargs) + + return result + + def freeze(self, array): + return array + + def thaw(self, array): + return array + + # }}} + + def transform_loopy_program(self, t_unit): + raise ValueError("NumpyArrayContext does not implement " + "transform_loopy_program. Sub-classes are supposed " + "to implement it.") + + def tag(self, tags: Union[Sequence[Tag], Tag], array): + # Numpy doesn't support tagging + return array + + def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array): + return array + + def einsum(self, spec, *args, arg_names=None, tagged=()): + return np.einsum(spec, *args) + + @property + def permits_inplace_modification(self): + return True + + @property + def supports_nonscalar_broadcasting(self): + return True + + @property + def permits_advanced_indexing(self): + return True diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py new file mode 100644 index 0000000..54867c8 --- /dev/null +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -0,0 +1,143 @@ +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from functools import partial, reduce + +import numpy as np + +from arraycontext.container import is_array_container +from arraycontext.container.traversal import ( + multimap_reduce_array_container, + rec_map_array_container, + rec_map_reduce_array_container, + rec_multimap_array_container, + rec_multimap_reduce_array_container, +) +from arraycontext.fake_numpy import ( + BaseFakeNumpyLinalgNamespace, + BaseFakeNumpyNamespace, +) + + +class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): + # Everything is implemented in the base class for now. + pass + + +_NUMPY_UFUNCS = frozenset({"concatenate", "reshape", "transpose", + "ones_like", "where", + *BaseFakeNumpyNamespace._numpy_math_functions + }) + + +class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace): + """ + A :mod:`numpy` mimic for :class:`NumpyArrayContext`. + """ + def _get_fake_numpy_linalg_namespace(self): + return NumpyFakeNumpyLinalgNamespace(self._array_context) + + def zeros(self, shape, dtype): + return np.zeros(shape, dtype) + + def __getattr__(self, name): + + if name in _NUMPY_UFUNCS: + from functools import partial + return partial(rec_multimap_array_container, + getattr(np, name)) + + raise AttributeError(name) + + def sum(self, a, axis=None, dtype=None): + return rec_map_reduce_array_container(sum, partial(np.sum, + axis=axis, + dtype=dtype), + a) + + def min(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, np.minimum), partial(np.amin, axis=axis), a) + + def max(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, np.maximum), partial(np.amax, axis=axis), a) + + def stack(self, arrays, axis=0): + return rec_multimap_array_container( + lambda *args: np.stack(arrays=args, axis=axis), + *arrays) + + def broadcast_to(self, array, shape): + return rec_map_array_container(partial(np.broadcast_to, shape=shape), array) + + # {{{ relational operators + + def equal(self, x, y): + return rec_multimap_array_container(np.equal, x, y) + + def not_equal(self, x, y): + return rec_multimap_array_container(np.not_equal, x, y) + + def greater(self, x, y): + return rec_multimap_array_container(np.greater, x, y) + + def greater_equal(self, x, y): + return rec_multimap_array_container(np.greater_equal, x, y) + + def less(self, x, y): + return rec_multimap_array_container(np.less, x, y) + + def less_equal(self, x, y): + return rec_multimap_array_container(np.less_equal, x, y) + + # }}} + + def ravel(self, a, order="C"): + return rec_map_array_container(partial(np.ravel, order=order), a) + + def vdot(self, x, y): + return rec_multimap_reduce_array_container(sum, np.vdot, x, y) + + def any(self, a): + return rec_map_reduce_array_container(partial(reduce, np.logical_or), + lambda subary: np.any(subary), a) + + def all(self, a): + return rec_map_reduce_array_container(partial(reduce, np.logical_and), + lambda subary: np.all(subary), a) + + def array_equal(self, a, b): + if type(a) != type(b): + return False + elif not is_array_container(a): + if a.shape != b.shape: + return False + else: + return np.all(np.equal(a, b)) + else: + return multimap_reduce_array_container(partial(reduce, + np.logical_and), + self.array_equal, a, b) + +# vim: fdm=marker diff --git a/doc/implementations.rst b/doc/implementations.rst index db35cca..4023e37 100644 --- a/doc/implementations.rst +++ b/doc/implementations.rst @@ -8,6 +8,11 @@ Implementations of the Array Context Abstraction ``` to update the coverage table below! +Array context based on :mod:`numpy` +-------------------------------------------- + +.. automodule:: arraycontext.impl.numpy + Array context based on :mod:`pyopencl.array` -------------------------------------------- -- GitLab