From 958d28180c53e85072fb21249bea1113409fd4f2 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Wed, 29 Jun 2022 20:10:46 +0300
Subject: [PATCH] deprecate empty / empty_like / zeros_like

---
 arraycontext/context.py                  | 10 +++++++
 arraycontext/fake_numpy.py               |  6 ----
 arraycontext/impl/jax/__init__.py        | 18 ++++++++++--
 arraycontext/impl/jax/fake_numpy.py      | 19 ++++++++++++
 arraycontext/impl/pyopencl/__init__.py   | 21 ++++++++++----
 arraycontext/impl/pyopencl/fake_numpy.py | 37 ++++++++++++++++++++----
 arraycontext/impl/pytato/__init__.py     | 16 ++++++----
 arraycontext/impl/pytato/fake_numpy.py   |  7 +++++
 8 files changed, 108 insertions(+), 26 deletions(-)

diff --git a/arraycontext/context.py b/arraycontext/context.py
index 36a7ace..e152838 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -299,9 +299,19 @@ class ArrayContext(ABC):
         pass
 
     def empty_like(self, ary: Array) -> Array:
+        from warnings import warn
+        warn(f"{type(self).__name__}.empty_like is deprecated and will stop "
+            "working in 2023. Prefer actx.np.zeros_like instead.",
+            DeprecationWarning, stacklevel=2)
+
         return self.empty(shape=ary.shape, dtype=ary.dtype)
 
     def zeros_like(self, ary: Array) -> Array:
+        from warnings import warn
+        warn(f"{type(self).__name__}.zeros_like is deprecated and will stop "
+            "working in 2023. Use actx.np.zeros_like instead.",
+            DeprecationWarning, stacklevel=2)
+
         return self.zeros(shape=ary.shape, dtype=ary.dtype)
 
     @abstractmethod
diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py
index d5c8fce..c3e37f8 100644
--- a/arraycontext/fake_numpy.py
+++ b/arraycontext/fake_numpy.py
@@ -91,12 +91,6 @@ class BaseFakeNumpyNamespace:
         # "interp",
         })
 
-    def empty_like(self, ary):
-        return self._array_context.empty_like(ary)
-
-    def zeros_like(self, ary):
-        return self._array_context.zeros_like(ary)
-
     def conjugate(self, x):
         # NOTE: conjugate distributes over object arrays, but it looks for a
         # `conjugate` ufunc, while some implementations only have the shorter
diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py
index dfb89c4..f4794e4 100644
--- a/arraycontext/impl/jax/__init__.py
+++ b/arraycontext/impl/jax/__init__.py
@@ -88,6 +88,11 @@ class EagerJAXArrayContext(ArrayContext):
     # {{{ ArrayContext interface
 
     def empty(self, shape, dtype):
+        from warnings import warn
+        warn(f"{type(self).__name__}.empty is deprecated and will stop "
+            "working in 2023. Prefer actx.zeros instead.",
+            DeprecationWarning, stacklevel=2)
+
         import jax.numpy as jnp
         return jnp.empty(shape=shape, dtype=dtype)
 
@@ -96,16 +101,23 @@ class EagerJAXArrayContext(ArrayContext):
         return jnp.zeros(shape=shape, dtype=dtype)
 
     def empty_like(self, ary):
+        from warnings import warn
+        warn(f"{type(self).__name__}.empty_like is deprecated and will stop "
+            "working in 2023. Prefer actx.np.zeros_like instead.",
+            DeprecationWarning, stacklevel=2)
+
         def _empty_like(array):
             return self.empty(array.shape, array.dtype)
 
         return self._rec_map_container(_empty_like, ary)
 
     def zeros_like(self, ary):
-        def _zeros_like(array):
-            return self.zeros(array.shape, array.dtype)
+        from warnings import warn
+        warn(f"{type(self).__name__}.zeros_like is deprecated and will stop "
+            "working in 2023. Use actx.np.zeros_like instead.",
+            DeprecationWarning, stacklevel=2)
 
-        return self._rec_map_container(_zeros_like, ary, default_scalar=0)
+        return self.np.zeros_like(ary)
 
     def from_numpy(self, array):
         def _from_numpy(ary):
diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py
index 37c99b4..daaf880 100644
--- a/arraycontext/impl/jax/fake_numpy.py
+++ b/arraycontext/impl/jax/fake_numpy.py
@@ -56,6 +56,25 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
 
     # {{{ array creation routines
 
+    def empty_like(self, ary):
+        from warnings import warn
+        warn(f"{type(self._array_context).__name__}.np.empty_like is "
+            "deprecated and will stop working in 2023. Prefer actx.np.zeros_like "
+            "instead.",
+            DeprecationWarning, stacklevel=2)
+
+        def _empty_like(array):
+            return self._array_context.empty(array.shape, array.dtype)
+
+        return self._array_context._rec_map_container(_empty_like, ary)
+
+    def zeros_like(self, ary):
+        def _zeros_like(array):
+            return self._array_context.zeros(array.shape, array.dtype)
+
+        return self._array_context._rec_map_container(
+            _zeros_like, ary, default_scalar=0)
+
     def ones_like(self, ary):
         return self.full_like(ary, 1)
 
diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py
index 71b04c7..aced309 100644
--- a/arraycontext/impl/pyopencl/__init__.py
+++ b/arraycontext/impl/pyopencl/__init__.py
@@ -189,6 +189,11 @@ class PyOpenCLArrayContext(ArrayContext):
     # {{{ ArrayContext interface
 
     def empty(self, shape, dtype):
+        from warnings import warn
+        warn(f"{type(self).__name__}.empty is deprecated and will stop "
+            "working in 2023. Prefer actx.zeros instead.",
+            DeprecationWarning, stacklevel=2)
+
         import arraycontext.impl.pyopencl.taggable_cl_array as tga
         return tga.empty(self.queue, shape, dtype, allocator=self.allocator)
 
@@ -197,6 +202,11 @@ class PyOpenCLArrayContext(ArrayContext):
         return tga.zeros(self.queue, shape, dtype, allocator=self.allocator)
 
     def empty_like(self, ary):
+        from warnings import warn
+        warn(f"{type(self).__name__}.empty_like is deprecated and will stop "
+            "working in 2023. Prefer actx.np.zeros_like instead.",
+            DeprecationWarning, stacklevel=2)
+
         import arraycontext.impl.pyopencl.taggable_cl_array as tga
 
         def _empty_like(array):
@@ -206,13 +216,12 @@ class PyOpenCLArrayContext(ArrayContext):
         return self._rec_map_container(_empty_like, ary)
 
     def zeros_like(self, ary):
-        import arraycontext.impl.pyopencl.taggable_cl_array as tga
-
-        def _zeros_like(array):
-            return tga.zeros(self.queue, array.shape, array.dtype,
-                allocator=self.allocator, axes=array.axes, tags=array.tags)
+        from warnings import warn
+        warn(f"{type(self).__name__}.zeros_like is deprecated and will stop "
+            "working in 2023. Use actx.np.zeros_like instead.",
+            DeprecationWarning, stacklevel=2)
 
-        return self._rec_map_container(_zeros_like, ary, default_scalar=0)
+        return self.np.zeros_like(ary)
 
     def from_numpy(self, array):
         import arraycontext.impl.pyopencl.taggable_cl_array as tga
diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py
index 2e206a8..a0180e7 100644
--- a/arraycontext/impl/pyopencl/fake_numpy.py
+++ b/arraycontext/impl/pyopencl/fake_numpy.py
@@ -63,22 +63,49 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
 
     # {{{ array creation routines
 
+    def empty_like(self, ary):
+        from warnings import warn
+        warn(f"{type(self._array_context).__name__}.np.empty_like is "
+            "deprecated and will stop working in 2023. Prefer actx.np.zeros_like "
+            "instead.",
+            DeprecationWarning, stacklevel=2)
+
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
+        actx = self._array_context
+
+        def _empty_like(array):
+            return tga.empty(actx.queue, array.shape, array.dtype,
+                allocator=actx.allocator, axes=array.axes, tags=array.tags)
+
+        return actx._rec_map_container(_empty_like, ary)
+
+    def zeros_like(self, ary):
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
+        actx = self._array_context
+
+        def _zeros_like(array):
+            return tga.zeros(
+                actx.queue, array.shape, array.dtype,
+                allocator=actx.allocator, axes=array.axes, tags=array.tags)
+
+        return actx._rec_map_container(_zeros_like, ary, default_scalar=0)
+
     def ones_like(self, ary):
         return self.full_like(ary, 1)
 
     def full_like(self, ary, fill_value):
         import arraycontext.impl.pyopencl.taggable_cl_array as tga
+        actx = self._array_context
 
         def _full_like(subary):
             filled = tga.empty(
-                self._array_context.queue, subary.shape, subary.dtype,
-                allocator=self._array_context.allocator,
-                axes=subary.axes, tags=subary.tags)
+                actx.queue, subary.shape, subary.dtype,
+                allocator=actx.allocator, axes=subary.axes, tags=subary.tags)
             filled.fill(fill_value)
+
             return filled
 
-        return self._array_context._rec_map_container(
-            _full_like, ary, default_scalar=fill_value)
+        return actx._rec_map_container(_full_like, ary, default_scalar=fill_value)
 
     def copy(self, ary):
         def _copy(subary):
diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index 6b9ac6b..8ccc768 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -348,10 +348,12 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
     # {{{ ArrayContext interface
 
     def zeros_like(self, ary):
-        def _zeros_like(array):
-            return self.zeros(array.shape, array.dtype)
+        from warnings import warn
+        warn(f"{type(self).__name__}.zeros_like is deprecated and will stop "
+            "working in 2023. Use actx.np.zeros_like instead.",
+            DeprecationWarning, stacklevel=2)
 
-        return self._rec_map_container(_zeros_like, ary, default_scalar=0)
+        return self.np.zeros_like(ary)
 
     def from_numpy(self, array):
         import pytato as pt
@@ -720,10 +722,12 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
     # {{{ ArrayContext interface
 
     def zeros_like(self, ary):
-        def _zeros_like(array):
-            return self.zeros(array.shape, array.dtype)
+        from warnings import warn
+        warn(f"{type(self).__name__}.zeros_like is deprecated and will stop "
+            "working in 2023. Use actx.np.zeros_like instead.",
+            DeprecationWarning, stacklevel=2)
 
-        return self._rec_map_container(_zeros_like, ary, default_scalar=0)
+        return self.np.zeros_like(ary)
 
     def from_numpy(self, array):
         import jax
diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py
index d1890f2..e17f8ee 100644
--- a/arraycontext/impl/pytato/fake_numpy.py
+++ b/arraycontext/impl/pytato/fake_numpy.py
@@ -76,6 +76,13 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
 
     # {{{ array creation routines
 
+    def zeros_like(self, ary):
+        def _zeros_like(array):
+            return self._array_context.zeros(array.shape, array.dtype)
+
+        return self._array_context._rec_map_container(
+            _zeros_like, ary, default_scalar=0)
+
     def ones_like(self, ary):
         return self.full_like(ary, 1)
 
-- 
GitLab