From 5c9d57ab40e5c5089ddddf120a400283304db1f5 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Sun, 26 Jun 2022 11:23:24 +0300 Subject: [PATCH] forward actx.empty_like in actx.np.empty_like --- arraycontext/fake_numpy.py | 18 ++----------- arraycontext/impl/jax/fake_numpy.py | 21 ++++++++-------- arraycontext/impl/pyopencl/fake_numpy.py | 32 ++++++++++++++++++------ arraycontext/impl/pytato/fake_numpy.py | 3 ++- 4 files changed, 40 insertions(+), 34 deletions(-) diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index 73c9e40..d5c8fce 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -91,25 +91,11 @@ class BaseFakeNumpyNamespace: # "interp", }) - def _new_like(self, ary, alloc_like): - if np.isscalar(ary): - # NOTE: `np.zeros_like(x)` returns `array(x, shape=())`, which - # is best implemented by concrete array contexts, if at all - raise NotImplementedError("operation not implemented for scalars") - - if isinstance(ary, np.ndarray) and ary.dtype.char == "O": - # NOTE: we don't want to match numpy semantics on object arrays, - # e.g. `np.zeros_like(x)` returns `array([0, 0, ...], dtype=object)` - # FIXME: what about object arrays nested in an ArrayContainer? - raise NotImplementedError("operation not implemented for object arrays") - - return rec_map_array_container(alloc_like, ary) - def empty_like(self, ary): - return self._new_like(ary, self._array_context.empty_like) + return self._array_context.empty_like(ary) def zeros_like(self, ary): - return self._new_like(ary, self._array_context.zeros_like) + return self._array_context.zeros_like(ary) def conjugate(self, x): # NOTE: conjugate distributes over object arrays, but it looks for a diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 6953afb..8e0308c 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -23,6 +23,9 @@ THE SOFTWARE. """ from functools import partial, reduce +import numpy as np +import jax.numpy as jnp + from arraycontext.fake_numpy import ( BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace, ) @@ -31,8 +34,6 @@ from arraycontext.container.traversal import ( rec_map_reduce_array_container, ) from arraycontext.container import NotAnArrayContainerError, serialize_container -import numpy -import jax.numpy as jnp class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): @@ -62,7 +63,8 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace): def _full_like(subary): return jnp.full_like(subary, fill_value) - return self._new_like(ary, _full_like) + return self._array_context._rec_map_container( + _full_like, ary, default_scalar=fill_value) # }}} @@ -111,11 +113,10 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace): from arraycontext import rec_multimap_reduce_array_container def _rec_vdot(ary1, ary2): - if dtype not in [None, numpy.find_common_type((ary1.dtype, - ary2.dtype), - ())]: - raise NotImplementedError(f"{type(self)} cannot take dtype in" - " vdot.") + common_dtype = np.find_common_type((ary1.dtype, ary2.dtype), ()) + if dtype not in [None, common_dtype]: + raise NotImplementedError( + f"{type(self).__name__} cannot take dtype in vdot.") return jnp.vdot(ary1, ary2) @@ -129,8 +130,8 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace): actx = self._array_context # NOTE: not all backends support `bool` properly, so use `int8` instead - true = actx.from_numpy(numpy.int8(True)) - false = actx.from_numpy(numpy.int8(False)) + true = actx.from_numpy(np.int8(True)) + false = actx.from_numpy(np.int8(False)) def rec_equal(x, y): if type(x) != type(y): diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 3c9be87..2e206a8 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -67,18 +67,24 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): return self.full_like(ary, 1) def full_like(self, ary, fill_value): + import arraycontext.impl.pyopencl.taggable_cl_array as tga + def _full_like(subary): - ones = self._array_context.empty_like(subary) - ones.fill(fill_value) - return ones + filled = tga.empty( + self._array_context.queue, subary.shape, subary.dtype, + allocator=self._array_context.allocator, + axes=subary.axes, tags=subary.tags) + filled.fill(fill_value) + return filled - return self._new_like(ary, _full_like) + return self._array_context._rec_map_container( + _full_like, ary, default_scalar=fill_value) def copy(self, ary): def _copy(subary): return subary.copy(queue=self._array_context.queue) - return self._new_like(ary, _copy) + return self._array_context._rec_map_container(_copy, ary) # }}} @@ -144,9 +150,15 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): def all(self, a): queue = self._array_context.queue + + def _all(ary): + if np.isscalar(ary): + return np.int8(all([ary])) + return ary.all(queue=queue) + result = rec_map_reduce_array_container( partial(reduce, partial(cl_array.minimum, queue=queue)), - lambda subary: subary.all(queue=queue), + _all, a) if not self._array_context._force_device_scalars: @@ -155,9 +167,15 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): def any(self, a): queue = self._array_context.queue + + def _any(ary): + if np.isscalar(ary): + return np.int8(any([ary])) + return ary.any(queue=queue) + result = rec_map_reduce_array_container( partial(reduce, partial(cl_array.maximum, queue=queue)), - lambda subary: subary.any(queue=queue), + _any, a) if not self._array_context._force_device_scalars: diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 0f219d1..d1890f2 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -83,7 +83,8 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): def _full_like(subary): return pt.full(subary.shape, fill_value, subary.dtype) - return self._new_like(ary, _full_like) + return self._array_context._rec_map_container( + _full_like, ary, default_scalar=fill_value) # }}} -- GitLab