From 3f8d0f6209e6c6ab4452e0854ce732d7717c641a Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 5 Aug 2024 12:42:46 -0500
Subject: [PATCH] Implement actx.np.zeros

---
 arraycontext/context.py                  | 18 ++++++++----------
 arraycontext/fake_numpy.py               | 11 ++++++++++-
 arraycontext/impl/jax/__init__.py        |  2 +-
 arraycontext/impl/jax/fake_numpy.py      |  3 +++
 arraycontext/impl/pyopencl/__init__.py   |  2 +-
 arraycontext/impl/pyopencl/fake_numpy.py |  6 ++++++
 arraycontext/impl/pytato/fake_numpy.py   |  3 +++
 test/test_arraycontext.py                | 12 ++++++------
 8 files changed, 38 insertions(+), 19 deletions(-)

diff --git a/arraycontext/context.py b/arraycontext/context.py
index 38c52dd..8b42bca 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -171,6 +171,7 @@ from typing import (
     TypeVar,
     Union,
 )
+from warnings import warn
 
 import numpy as np
 
@@ -249,10 +250,6 @@ class ArrayContext(ABC):
 
     .. versionadded:: 2020.2
 
-    .. automethod:: empty
-    .. automethod:: zeros
-    .. automethod:: empty_like
-    .. automethod:: zeros_like
     .. automethod:: from_numpy
     .. automethod:: to_numpy
     .. automethod:: call_loopy
@@ -293,9 +290,9 @@ class ArrayContext(ABC):
     def __init__(self) -> None:
         self.np = self._get_fake_numpy_namespace()
 
+    @abstractmethod
     def _get_fake_numpy_namespace(self) -> Any:
-        from .fake_numpy import BaseFakeNumpyNamespace
-        return BaseFakeNumpyNamespace(self)
+        ...
 
     def __hash__(self) -> int:
         raise TypeError(f"unhashable type: '{type(self).__name__}'")
@@ -306,14 +303,16 @@ class ArrayContext(ABC):
               dtype: "np.dtype[Any]") -> Array:
         pass
 
-    @abstractmethod
     def zeros(self,
               shape: Union[int, Tuple[int, ...]],
               dtype: "np.dtype[Any]") -> Array:
-        pass
+        warn(f"{type(self).__name__}.zeros is deprecated and will stop "
+            "working in 2025. Use actx.np.zeros instead.",
+            DeprecationWarning, stacklevel=2)
+
+        return self.np.zeros(shape, dtype)
 
     def empty_like(self, ary: Array) -> Array:
-        from warnings import warn
         warn(f"{type(self).__name__}.empty_like is deprecated and will stop "
             "working in 2023. Prefer actx.np.zeros_like instead.",
             DeprecationWarning, stacklevel=2)
@@ -321,7 +320,6 @@ class ArrayContext(ABC):
         return self.empty(shape=ary.shape, dtype=ary.dtype)
 
     def zeros_like(self, ary: Array) -> Array:
-        from warnings import warn
         warn(f"{type(self).__name__}.zeros_like is deprecated and will stop "
             "working in 2023. Use actx.np.zeros_like instead.",
             DeprecationWarning, stacklevel=2)
diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py
index e31bae7..8473cc4 100644
--- a/arraycontext/fake_numpy.py
+++ b/arraycontext/fake_numpy.py
@@ -24,6 +24,7 @@ THE SOFTWARE.
 
 
 import operator
+from abc import ABC, abstractmethod
 from typing import Any
 
 import numpy as np
@@ -34,7 +35,7 @@ from arraycontext.container.traversal import rec_map_array_container
 
 # {{{ BaseFakeNumpyNamespace
 
-class BaseFakeNumpyNamespace:
+class BaseFakeNumpyNamespace(ABC):
     def __init__(self, array_context):
         self._array_context = array_context
         self.linalg = self._get_fake_numpy_linalg_namespace()
@@ -95,6 +96,14 @@ class BaseFakeNumpyNamespace:
         # "interp",
         })
 
+    @abstractmethod
+    def zeros(self, shape, dtype):
+        ...
+
+    @abstractmethod
+    def zeros_like(self, ary):
+        ...
+
     def conjugate(self, x):
         # NOTE: conjugate distributes over object arrays, but it looks for a
         # `conjugate` ufunc, while some implementations only have the shorter
diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py
index c52b24a..0304541 100644
--- a/arraycontext/impl/jax/__init__.py
+++ b/arraycontext/impl/jax/__init__.py
@@ -90,7 +90,7 @@ class EagerJAXArrayContext(ArrayContext):
     def empty(self, shape, dtype):
         from warnings import warn
         warn(f"{type(self).__name__}.empty is deprecated and will stop "
-            "working in 2023. Prefer actx.zeros instead.",
+            "working in 2023. Prefer actx.np.zeros instead.",
             DeprecationWarning, stacklevel=2)
 
         import jax.numpy as jnp
diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py
index afe6728..3fc5f2e 100644
--- a/arraycontext/impl/jax/fake_numpy.py
+++ b/arraycontext/impl/jax/fake_numpy.py
@@ -56,6 +56,9 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
 
     # {{{ array creation routines
 
+    def zeros(self, shape, dtype):
+        return jnp.zeros(shape=shape, dtype=dtype)
+
     def empty_like(self, ary):
         from warnings import warn
         warn(f"{type(self._array_context).__name__}.np.empty_like is "
diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py
index 990a422..9be77a4 100644
--- a/arraycontext/impl/pyopencl/__init__.py
+++ b/arraycontext/impl/pyopencl/__init__.py
@@ -201,7 +201,7 @@ class PyOpenCLArrayContext(ArrayContext):
     def empty(self, shape, dtype):
         from warnings import warn
         warn(f"{type(self).__name__}.empty is deprecated and will stop "
-            "working in 2023. Prefer actx.zeros instead.",
+            "working in 2023. Prefer actx.np.zeros instead.",
             DeprecationWarning, stacklevel=2)
 
         import arraycontext.impl.pyopencl.taggable_cl_array as tga
diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py
index 2583bfa..59be99e 100644
--- a/arraycontext/impl/pyopencl/fake_numpy.py
+++ b/arraycontext/impl/pyopencl/fake_numpy.py
@@ -39,6 +39,7 @@ from arraycontext.container.traversal import (
     rec_multimap_reduce_array_container,
 )
 from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
+from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
 from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
 
 
@@ -60,6 +61,11 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
 
     # {{{ array creation routines
 
+    def zeros(self, shape, dtype) -> TaggableCLArray:
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
+        return tga.zeros(self._array_context.queue, shape, dtype,
+                         allocator=self._array_context.allocator)
+
     def empty_like(self, ary):
         from warnings import warn
         warn(f"{type(self._array_context).__name__}.np.empty_like is "
diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py
index 9c41b52..d3d018d 100644
--- a/arraycontext/impl/pytato/fake_numpy.py
+++ b/arraycontext/impl/pytato/fake_numpy.py
@@ -84,6 +84,9 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
 
     # {{{ array creation routines
 
+    def zeros(self, shape, dtype):
+        return pt.zeros(shape, dtype)
+
     def zeros_like(self, ary):
         def _zeros_like(array):
             return self._array_context.zeros(
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index fb16b87..3f06156 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -1367,7 +1367,7 @@ def test_leaf_array_type_broadcasting(actx_factory):
     # test support for https://github.com/inducer/arraycontext/issues/49
     actx = actx_factory()
 
-    foo = Foo(DOFArray(actx, (actx.zeros(3, dtype=np.float64) + 41, )))
+    foo = Foo(DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, )))
     bar = foo + 4
     baz = foo + actx.from_numpy(4*np.ones((3, )))
     qux = actx.from_numpy(4*np.ones((3, ))) + foo
@@ -1510,7 +1510,7 @@ def test_actx_compile_on_pure_array_return(actx_factory):
 
     actx = actx_factory()
     ones = actx.thaw(actx.freeze(
-        actx.zeros(shape=(10, 4), dtype=np.float64) + 1
+        actx.np.zeros(shape=(10, 4), dtype=np.float64) + 1
         ))
     np.testing.assert_allclose(actx.to_numpy(_twice(ones)),
                                actx.to_numpy(actx.compile(_twice)(ones)))
@@ -1573,7 +1573,7 @@ def test_taggable_cl_array_tags(actx_factory):
 def test_to_numpy_on_frozen_arrays(actx_factory):
     # See https://github.com/inducer/arraycontext/issues/159
     actx = actx_factory()
-    u = actx.freeze(actx.zeros(10, dtype="float64")+1)
+    u = actx.freeze(actx.np.zeros(10, dtype="float64")+1)
     np.testing.assert_allclose(actx.to_numpy(u), 1)
     np.testing.assert_allclose(actx.to_numpy(u), 1)
 
@@ -1592,7 +1592,7 @@ def test_tagging(actx_factory):
     ary = tag_axes(actx, {0: ExampleTag()},
             actx.tag(
                 ExampleTag(),
-                actx.zeros((20, 20), dtype=np.float64)))
+                actx.np.zeros((20, 20), dtype=np.float64)))
 
     assert ary.tags_of_type(ExampleTag)
     assert ary.axes[0].tags_of_type(ExampleTag)
@@ -1606,11 +1606,11 @@ def test_compile_anonymous_function(actx_factory):
     actx = actx_factory()
     f = actx.compile(lambda x: 2*x+40)
     np.testing.assert_allclose(
-        actx.to_numpy(f(1+actx.zeros((10, 4), "float64"))),
+        actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
         42)
     f = actx.compile(partial(lambda x: 2*x+40))
     np.testing.assert_allclose(
-        actx.to_numpy(f(1+actx.zeros((10, 4), "float64"))),
+        actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
         42)
 
 
-- 
GitLab