From 9ec29c59d7c59a8c4da4b3b656e9ad66ba7cee09 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 Jul 2024 14:20:32 -0500 Subject: [PATCH] Improve, type, fix array_equal across all array contexts --- arraycontext/impl/jax/fake_numpy.py | 28 ++++++++++++------ arraycontext/impl/numpy/fake_numpy.py | 37 +++++++++++++++--------- arraycontext/impl/pyopencl/fake_numpy.py | 29 +++++++++++++------ arraycontext/impl/pytato/fake_numpy.py | 35 ++++++++++++++-------- 4 files changed, 86 insertions(+), 43 deletions(-) diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 3fc5f2e..bc9481e 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -27,12 +27,16 @@ import numpy as np import jax.numpy as jnp -from arraycontext.container import NotAnArrayContainerError, serialize_container +from arraycontext.container import ( + NotAnArrayContainerError, + serialize_container, +) from arraycontext.container.traversal import ( rec_map_array_container, rec_map_reduce_array_container, rec_multimap_array_container, ) +from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace @@ -156,29 +160,35 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace): return rec_map_reduce_array_container( partial(reduce, jnp.logical_or), jnp.any, a) - def array_equal(self, a, b): + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: actx = self._array_context # NOTE: not all backends support `bool` properly, so use `int8` instead - true = actx.from_numpy(np.int8(True)) - false = actx.from_numpy(np.int8(False)) + true_ary = actx.from_numpy(np.int8(True)) + false_ary = actx.from_numpy(np.int8(False)) def rec_equal(x, y): if type(x) is not type(y): - return false + return false_ary try: - iterable = zip(serialize_container(x), serialize_container(y)) + serialized_x = serialize_container(x) + serialized_y = serialize_container(y) except NotAnArrayContainerError: if x.shape != y.shape: - return false + return false_ary else: return jnp.all(jnp.equal(x, y)) else: + if len(serialized_x) != len(serialized_y): + return false_ary return reduce( jnp.logical_and, - [rec_equal(x_i, y_i) for (_, x_i), (_, y_i) in iterable], - true) + [(true_ary if kx_i == ky_i else false_ary) + and rec_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) return rec_equal(a, b) diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index b7a2335..b305717 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -25,14 +25,14 @@ from functools import partial, reduce import numpy as np -from arraycontext.container import is_array_container +from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( - multimap_reduce_array_container, rec_map_array_container, rec_map_reduce_array_container, rec_multimap_array_container, rec_multimap_reduce_array_container, ) +from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import ( BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace, @@ -127,18 +127,29 @@ class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace): return rec_map_reduce_array_container(partial(reduce, np.logical_and), lambda subary: np.all(subary), a) - def array_equal(self, a, b): - if type(a) != type(b): - return False - elif not is_array_container(a): - if a.shape != b.shape: - return False - else: - return np.all(np.equal(a, b)) + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: + false_ary = np.array(False) + true_ary = np.array(True) + if type(a) is not type(b): + return false_ary + + try: + serialized_x = serialize_container(a) + serialized_y = serialize_container(b) + except NotAnArrayContainerError: + assert isinstance(a, np.ndarray) + assert isinstance(b, np.ndarray) + return np.array(np.array_equal(a, b)) else: - return multimap_reduce_array_container(partial(reduce, - np.logical_and), - self.array_equal, a, b) + if len(serialized_x) != len(serialized_y): + return false_ary + return reduce( + np.logical_and, + [(true_ary if kx_i == ky_i else false_ary) + and self.array_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) def arange(self, *args, **kwargs): return np.arange(*args, **kwargs) diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 59be99e..848870a 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -38,6 +38,7 @@ from arraycontext.container.traversal import ( rec_multimap_array_container, rec_multimap_reduce_array_container, ) +from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray from arraycontext.loopy import LoopyBasedFakeNumpyNamespace @@ -215,30 +216,40 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): result = result.get()[()] return result - def array_equal(self, a, b): + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: actx = self._array_context queue = actx.queue # NOTE: pyopencl doesn't like `bool` much, so use `int8` instead - true = actx.from_numpy(np.int8(True)) - false = actx.from_numpy(np.int8(False)) + true_ary = actx.from_numpy(np.int8(True)) + false_ary = actx.from_numpy(np.int8(False)) - def rec_equal(x, y): + def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> cl_array.Array: if type(x) is not type(y): - return false + return false_ary try: - iterable = zip(serialize_container(x), serialize_container(y)) + serialized_x = serialize_container(x) + serialized_y = serialize_container(y) except NotAnArrayContainerError: + assert isinstance(x, cl_array.Array) + assert isinstance(y, cl_array.Array) + if x.shape != y.shape: - return false + return false_ary else: return (x == y).all() else: + if len(serialized_x) != len(serialized_y): + return false_ary + return reduce( partial(cl_array.minimum, queue=queue), - [rec_equal(x_i, y_i)for (_, x_i), (_, y_i) in iterable], - true) + [(true_ary if kx_i == ky_i else false_ary) + and rec_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) result = rec_equal(a, b) if not self._array_context._force_device_scalars: diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index d3d018d..c6508e3 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -22,7 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ from functools import partial, reduce -from typing import Any +from typing import Any, cast import numpy as np @@ -34,6 +34,7 @@ from arraycontext.container.traversal import ( rec_map_reduce_array_container, rec_multimap_array_container, ) +from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace from arraycontext.loopy import LoopyBasedFakeNumpyNamespace @@ -171,31 +172,41 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): partial(reduce, pt.logical_or), lambda subary: pt.any(subary), a) - def array_equal(self, a, b): + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: actx = self._array_context # NOTE: not all backends support `bool` properly, so use `int8` instead - true = actx.from_numpy(np.int8(True)) - false = actx.from_numpy(np.int8(False)) + true_ary = actx.from_numpy(np.int8(True)) + false_ary = actx.from_numpy(np.int8(False)) - def rec_equal(x, y): + def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> pt.Array: if type(x) is not type(y): - return false + return false_ary try: - iterable = zip(serialize_container(x), serialize_container(y)) + serialized_x = serialize_container(x) + serialized_y = serialize_container(y) except NotAnArrayContainerError: + assert isinstance(x, pt.Array) + assert isinstance(y, pt.Array) + if x.shape != y.shape: - return false + return false_ary else: - return pt.all(pt.equal(x, y)) + return pt.all(cast(pt.Array, pt.equal(x, y))) else: + if len(serialized_x) != len(serialized_y): + return false_ary + return reduce( pt.logical_and, - [rec_equal(x_i, y_i) for (_, x_i), (_, y_i) in iterable], - true) + [(true_ary if kx_i == ky_i else false_ary) + and rec_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) - return rec_equal(a, b) + return cast(Array, rec_equal(a, b)) # }}} -- GitLab