From 5bf46bc61c3325981852244130544561c26be139 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Sun, 26 Sep 2021 02:38:28 -0500
Subject: [PATCH] Implements NumpyArrayContext

---
 arraycontext/__init__.py              |   7 +-
 arraycontext/container/arithmetic.py  |   1 -
 arraycontext/impl/numpy/__init__.py   | 126 +++++++++++++++++++++++
 arraycontext/impl/numpy/fake_numpy.py | 143 ++++++++++++++++++++++++++
 doc/implementations.rst               |   5 +
 5 files changed, 278 insertions(+), 4 deletions(-)
 create mode 100644 arraycontext/impl/numpy/__init__.py
 create mode 100644 arraycontext/impl/numpy/fake_numpy.py

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index 1d0efb3..2f2640d 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -78,6 +78,7 @@ from .context import (
     tag_axes,
 )
 from .impl.jax import EagerJAXArrayContext
+from .impl.numpy import NumpyArrayContext
 from .impl.pyopencl import PyOpenCLArrayContext
 from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
 from .loopy import make_loopy_program
@@ -91,7 +92,6 @@ from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag
 
 
 __all__ = (
-    "Array",
     "Array",
     "ArrayContainer",
     "ArrayContainerT",
@@ -105,13 +105,13 @@ __all__ = (
     "EagerJAXArrayContext",
     "ElementwiseMapKernelTag",
     "NotAnArrayContainerError",
+    "NumpyArrayContext",
     "PyOpenCLArrayContext",
     "PytatoJAXArrayContext",
     "PytatoPyOpenCLArrayContext",
     "PytestArrayContextFactory",
     "PytestPyOpenCLArrayContextFactory",
     "Scalar",
-    "Scalar",
     "ScalarLike",
     "dataclass_array_container",
     "deserialize_container",
@@ -146,8 +146,9 @@ __all__ = (
     "to_numpy",
     "unflatten",
     "with_array_context",
+    "with_container_arithmetic",
     "with_container_arithmetic"
-)
+    )
 
 
 # {{{ deprecation handling
diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py
index 4c8a09a..63f9327 100644
--- a/arraycontext/container/arithmetic.py
+++ b/arraycontext/container/arithmetic.py
@@ -33,7 +33,6 @@ THE SOFTWARE.
 """
 
 from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union
-from warnings import warn
 
 import numpy as np
 
diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py
new file mode 100644
index 0000000..dbc725f
--- /dev/null
+++ b/arraycontext/impl/numpy/__init__.py
@@ -0,0 +1,126 @@
+"""
+.. currentmodule:: arraycontext
+
+A mod :`numpy`-based array context.
+
+.. autoclass:: NumpyArrayContext
+"""
+
+__copyright__ = """
+Copyright (C) 2021 University of Illinois Board of Trustees
+"""
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+from typing import Dict, Sequence, Union
+
+import numpy as np
+
+import loopy as lp
+from pytools.tag import Tag
+
+from arraycontext.context import ArrayContext
+
+
+class NumpyArrayContext(ArrayContext):
+    """
+    A :class:`ArrayContext` that uses :class:`numpy.ndarray` to represent arrays.
+
+    .. automethod:: __init__
+    """
+    def __init__(self):
+        super().__init__()
+        self._loopy_transform_cache: \
+                Dict[lp.TranslationUnit, lp.TranslationUnit] = {}
+
+        self.array_types = (np.ndarray,)
+
+    def _get_fake_numpy_namespace(self):
+        from .fake_numpy import NumpyFakeNumpyNamespace
+        return NumpyFakeNumpyNamespace(self)
+
+    # {{{ ArrayContext interface
+
+    def clone(self):
+        return type(self)()
+
+    def empty(self, shape, dtype):
+        return np.empty(shape, dtype=dtype)
+
+    def zeros(self, shape, dtype):
+        return np.zeros(shape, dtype)
+
+    def from_numpy(self, np_array: np.ndarray):
+        # Uh oh...
+        return np_array
+
+    def to_numpy(self, array):
+        # Uh oh...
+        return array
+
+    def call_loopy(self, t_unit, **kwargs):
+        t_unit = t_unit.copy(target=lp.ExecutableCTarget())
+        try:
+            t_unit = self._loopy_transform_cache[t_unit]
+        except KeyError:
+            orig_t_unit = t_unit
+            t_unit = self.transform_loopy_program(t_unit)
+            self._loopy_transform_cache[orig_t_unit] = t_unit
+            del orig_t_unit
+
+        _, result = t_unit(**kwargs)
+
+        return result
+
+    def freeze(self, array):
+        return array
+
+    def thaw(self, array):
+        return array
+
+    # }}}
+
+    def transform_loopy_program(self, t_unit):
+        raise ValueError("NumpyArrayContext does not implement "
+                         "transform_loopy_program. Sub-classes are supposed "
+                         "to implement it.")
+
+    def tag(self, tags: Union[Sequence[Tag], Tag], array):
+        # Numpy doesn't support tagging
+        return array
+
+    def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array):
+        return array
+
+    def einsum(self, spec, *args, arg_names=None, tagged=()):
+        return np.einsum(spec, *args)
+
+    @property
+    def permits_inplace_modification(self):
+        return True
+
+    @property
+    def supports_nonscalar_broadcasting(self):
+        return True
+
+    @property
+    def permits_advanced_indexing(self):
+        return True
diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py
new file mode 100644
index 0000000..54867c8
--- /dev/null
+++ b/arraycontext/impl/numpy/fake_numpy.py
@@ -0,0 +1,143 @@
+__copyright__ = """
+Copyright (C) 2021 University of Illinois Board of Trustees
+"""
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+from functools import partial, reduce
+
+import numpy as np
+
+from arraycontext.container import is_array_container
+from arraycontext.container.traversal import (
+    multimap_reduce_array_container,
+    rec_map_array_container,
+    rec_map_reduce_array_container,
+    rec_multimap_array_container,
+    rec_multimap_reduce_array_container,
+)
+from arraycontext.fake_numpy import (
+    BaseFakeNumpyLinalgNamespace,
+    BaseFakeNumpyNamespace,
+)
+
+
+class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
+    # Everything is implemented in the base class for now.
+    pass
+
+
+_NUMPY_UFUNCS = frozenset({"concatenate", "reshape", "transpose",
+                 "ones_like", "where",
+                 *BaseFakeNumpyNamespace._numpy_math_functions
+                 })
+
+
+class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace):
+    """
+    A :mod:`numpy` mimic for :class:`NumpyArrayContext`.
+    """
+    def _get_fake_numpy_linalg_namespace(self):
+        return NumpyFakeNumpyLinalgNamespace(self._array_context)
+
+    def zeros(self, shape, dtype):
+        return np.zeros(shape, dtype)
+
+    def __getattr__(self, name):
+
+        if name in _NUMPY_UFUNCS:
+            from functools import partial
+            return partial(rec_multimap_array_container,
+                           getattr(np, name))
+
+        raise AttributeError(name)
+
+    def sum(self, a, axis=None, dtype=None):
+        return rec_map_reduce_array_container(sum, partial(np.sum,
+                                                           axis=axis,
+                                                           dtype=dtype),
+                                              a)
+
+    def min(self, a, axis=None):
+        return rec_map_reduce_array_container(
+                partial(reduce, np.minimum), partial(np.amin, axis=axis), a)
+
+    def max(self, a, axis=None):
+        return rec_map_reduce_array_container(
+                partial(reduce, np.maximum), partial(np.amax, axis=axis), a)
+
+    def stack(self, arrays, axis=0):
+        return rec_multimap_array_container(
+                lambda *args: np.stack(arrays=args, axis=axis),
+                *arrays)
+
+    def broadcast_to(self, array, shape):
+        return rec_map_array_container(partial(np.broadcast_to, shape=shape), array)
+
+    # {{{ relational operators
+
+    def equal(self, x, y):
+        return rec_multimap_array_container(np.equal, x, y)
+
+    def not_equal(self, x, y):
+        return rec_multimap_array_container(np.not_equal, x, y)
+
+    def greater(self, x, y):
+        return rec_multimap_array_container(np.greater, x, y)
+
+    def greater_equal(self, x, y):
+        return rec_multimap_array_container(np.greater_equal, x, y)
+
+    def less(self, x, y):
+        return rec_multimap_array_container(np.less, x, y)
+
+    def less_equal(self, x, y):
+        return rec_multimap_array_container(np.less_equal, x, y)
+
+    # }}}
+
+    def ravel(self, a, order="C"):
+        return rec_map_array_container(partial(np.ravel, order=order), a)
+
+    def vdot(self, x, y):
+        return rec_multimap_reduce_array_container(sum, np.vdot, x, y)
+
+    def any(self, a):
+        return rec_map_reduce_array_container(partial(reduce, np.logical_or),
+                                              lambda subary: np.any(subary), a)
+
+    def all(self, a):
+        return rec_map_reduce_array_container(partial(reduce, np.logical_and),
+                                              lambda subary: np.all(subary), a)
+
+    def array_equal(self, a, b):
+        if type(a) != type(b):
+            return False
+        elif not is_array_container(a):
+            if a.shape != b.shape:
+                return False
+            else:
+                return np.all(np.equal(a, b))
+        else:
+            return multimap_reduce_array_container(partial(reduce,
+                                                           np.logical_and),
+                                                   self.array_equal, a, b)
+
+# vim: fdm=marker
diff --git a/doc/implementations.rst b/doc/implementations.rst
index db35cca..4023e37 100644
--- a/doc/implementations.rst
+++ b/doc/implementations.rst
@@ -8,6 +8,11 @@ Implementations of the Array Context Abstraction
     ```
     to update the coverage table below!
 
+Array context based on :mod:`numpy`
+--------------------------------------------
+
+.. automodule:: arraycontext.impl.numpy
+
 Array context based on :mod:`pyopencl.array`
 --------------------------------------------
 
-- 
GitLab