From 5c9d57ab40e5c5089ddddf120a400283304db1f5 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Sun, 26 Jun 2022 11:23:24 +0300
Subject: [PATCH] forward actx.empty_like in actx.np.empty_like

---
 arraycontext/fake_numpy.py               | 18 ++-----------
 arraycontext/impl/jax/fake_numpy.py      | 21 ++++++++--------
 arraycontext/impl/pyopencl/fake_numpy.py | 32 ++++++++++++++++++------
 arraycontext/impl/pytato/fake_numpy.py   |  3 ++-
 4 files changed, 40 insertions(+), 34 deletions(-)

diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py
index 73c9e40..d5c8fce 100644
--- a/arraycontext/fake_numpy.py
+++ b/arraycontext/fake_numpy.py
@@ -91,25 +91,11 @@ class BaseFakeNumpyNamespace:
         # "interp",
         })
 
-    def _new_like(self, ary, alloc_like):
-        if np.isscalar(ary):
-            # NOTE: `np.zeros_like(x)` returns `array(x, shape=())`, which
-            # is best implemented by concrete array contexts, if at all
-            raise NotImplementedError("operation not implemented for scalars")
-
-        if isinstance(ary, np.ndarray) and ary.dtype.char == "O":
-            # NOTE: we don't want to match numpy semantics on object arrays,
-            # e.g. `np.zeros_like(x)` returns `array([0, 0, ...], dtype=object)`
-            # FIXME: what about object arrays nested in an ArrayContainer?
-            raise NotImplementedError("operation not implemented for object arrays")
-
-        return rec_map_array_container(alloc_like, ary)
-
     def empty_like(self, ary):
-        return self._new_like(ary, self._array_context.empty_like)
+        return self._array_context.empty_like(ary)
 
     def zeros_like(self, ary):
-        return self._new_like(ary, self._array_context.zeros_like)
+        return self._array_context.zeros_like(ary)
 
     def conjugate(self, x):
         # NOTE: conjugate distributes over object arrays, but it looks for a
diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py
index 6953afb..8e0308c 100644
--- a/arraycontext/impl/jax/fake_numpy.py
+++ b/arraycontext/impl/jax/fake_numpy.py
@@ -23,6 +23,9 @@ THE SOFTWARE.
 """
 from functools import partial, reduce
 
+import numpy as np
+import jax.numpy as jnp
+
 from arraycontext.fake_numpy import (
         BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace,
         )
@@ -31,8 +34,6 @@ from arraycontext.container.traversal import (
         rec_map_reduce_array_container,
         )
 from arraycontext.container import NotAnArrayContainerError, serialize_container
-import numpy
-import jax.numpy as jnp
 
 
 class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
@@ -62,7 +63,8 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
         def _full_like(subary):
             return jnp.full_like(subary, fill_value)
 
-        return self._new_like(ary, _full_like)
+        return self._array_context._rec_map_container(
+            _full_like, ary, default_scalar=fill_value)
 
     # }}}
 
@@ -111,11 +113,10 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
         from arraycontext import rec_multimap_reduce_array_container
 
         def _rec_vdot(ary1, ary2):
-            if dtype not in [None, numpy.find_common_type((ary1.dtype,
-                                                           ary2.dtype),
-                                                          ())]:
-                raise NotImplementedError(f"{type(self)} cannot take dtype in"
-                                          " vdot.")
+            common_dtype = np.find_common_type((ary1.dtype, ary2.dtype), ())
+            if dtype not in [None, common_dtype]:
+                raise NotImplementedError(
+                    f"{type(self).__name__} cannot take dtype in vdot.")
 
             return jnp.vdot(ary1, ary2)
 
@@ -129,8 +130,8 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
         actx = self._array_context
 
         # NOTE: not all backends support `bool` properly, so use `int8` instead
-        true = actx.from_numpy(numpy.int8(True))
-        false = actx.from_numpy(numpy.int8(False))
+        true = actx.from_numpy(np.int8(True))
+        false = actx.from_numpy(np.int8(False))
 
         def rec_equal(x, y):
             if type(x) != type(y):
diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py
index 3c9be87..2e206a8 100644
--- a/arraycontext/impl/pyopencl/fake_numpy.py
+++ b/arraycontext/impl/pyopencl/fake_numpy.py
@@ -67,18 +67,24 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
         return self.full_like(ary, 1)
 
     def full_like(self, ary, fill_value):
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
+
         def _full_like(subary):
-            ones = self._array_context.empty_like(subary)
-            ones.fill(fill_value)
-            return ones
+            filled = tga.empty(
+                self._array_context.queue, subary.shape, subary.dtype,
+                allocator=self._array_context.allocator,
+                axes=subary.axes, tags=subary.tags)
+            filled.fill(fill_value)
+            return filled
 
-        return self._new_like(ary, _full_like)
+        return self._array_context._rec_map_container(
+            _full_like, ary, default_scalar=fill_value)
 
     def copy(self, ary):
         def _copy(subary):
             return subary.copy(queue=self._array_context.queue)
 
-        return self._new_like(ary, _copy)
+        return self._array_context._rec_map_container(_copy, ary)
 
     # }}}
 
@@ -144,9 +150,15 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
 
     def all(self, a):
         queue = self._array_context.queue
+
+        def _all(ary):
+            if np.isscalar(ary):
+                return np.int8(all([ary]))
+            return ary.all(queue=queue)
+
         result = rec_map_reduce_array_container(
                 partial(reduce, partial(cl_array.minimum, queue=queue)),
-                lambda subary: subary.all(queue=queue),
+                _all,
                 a)
 
         if not self._array_context._force_device_scalars:
@@ -155,9 +167,15 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
 
     def any(self, a):
         queue = self._array_context.queue
+
+        def _any(ary):
+            if np.isscalar(ary):
+                return np.int8(any([ary]))
+            return ary.any(queue=queue)
+
         result = rec_map_reduce_array_container(
                 partial(reduce, partial(cl_array.maximum, queue=queue)),
-                lambda subary: subary.any(queue=queue),
+                _any,
                 a)
 
         if not self._array_context._force_device_scalars:
diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py
index 0f219d1..d1890f2 100644
--- a/arraycontext/impl/pytato/fake_numpy.py
+++ b/arraycontext/impl/pytato/fake_numpy.py
@@ -83,7 +83,8 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
         def _full_like(subary):
             return pt.full(subary.shape, fill_value, subary.dtype)
 
-        return self._new_like(ary, _full_like)
+        return self._array_context._rec_map_container(
+            _full_like, ary, default_scalar=fill_value)
 
     # }}}
 
-- 
GitLab