diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 3fc5f2e6eaee36eef7c4a3802df81fd1f4068fa5..bc9481e36067ab07300bf0c10431eed2fac889d0 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -27,12 +27,16 @@ import numpy as np import jax.numpy as jnp -from arraycontext.container import NotAnArrayContainerError, serialize_container +from arraycontext.container import ( + NotAnArrayContainerError, + serialize_container, +) from arraycontext.container.traversal import ( rec_map_array_container, rec_map_reduce_array_container, rec_multimap_array_container, ) +from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace @@ -156,29 +160,35 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace): return rec_map_reduce_array_container( partial(reduce, jnp.logical_or), jnp.any, a) - def array_equal(self, a, b): + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: actx = self._array_context # NOTE: not all backends support `bool` properly, so use `int8` instead - true = actx.from_numpy(np.int8(True)) - false = actx.from_numpy(np.int8(False)) + true_ary = actx.from_numpy(np.int8(True)) + false_ary = actx.from_numpy(np.int8(False)) def rec_equal(x, y): if type(x) is not type(y): - return false + return false_ary try: - iterable = zip(serialize_container(x), serialize_container(y)) + serialized_x = serialize_container(x) + serialized_y = serialize_container(y) except NotAnArrayContainerError: if x.shape != y.shape: - return false + return false_ary else: return jnp.all(jnp.equal(x, y)) else: + if len(serialized_x) != len(serialized_y): + return false_ary return reduce( jnp.logical_and, - [rec_equal(x_i, y_i) for (_, x_i), (_, y_i) in iterable], - true) + [(true_ary if kx_i == ky_i else false_ary) + and rec_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) return rec_equal(a, b) diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index b7a2335a1a50ca58a235e180b66c264e619ebabd..b305717e15af76e05e4ea6cec909875d1061e6b4 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -25,14 +25,14 @@ from functools import partial, reduce import numpy as np -from arraycontext.container import is_array_container +from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( - multimap_reduce_array_container, rec_map_array_container, rec_map_reduce_array_container, rec_multimap_array_container, rec_multimap_reduce_array_container, ) +from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import ( BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace, @@ -127,18 +127,29 @@ class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace): return rec_map_reduce_array_container(partial(reduce, np.logical_and), lambda subary: np.all(subary), a) - def array_equal(self, a, b): - if type(a) != type(b): - return False - elif not is_array_container(a): - if a.shape != b.shape: - return False - else: - return np.all(np.equal(a, b)) + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: + false_ary = np.array(False) + true_ary = np.array(True) + if type(a) is not type(b): + return false_ary + + try: + serialized_x = serialize_container(a) + serialized_y = serialize_container(b) + except NotAnArrayContainerError: + assert isinstance(a, np.ndarray) + assert isinstance(b, np.ndarray) + return np.array(np.array_equal(a, b)) else: - return multimap_reduce_array_container(partial(reduce, - np.logical_and), - self.array_equal, a, b) + if len(serialized_x) != len(serialized_y): + return false_ary + return reduce( + np.logical_and, + [(true_ary if kx_i == ky_i else false_ary) + and self.array_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) def arange(self, *args, **kwargs): return np.arange(*args, **kwargs) diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 59be99e8a2c4099882d7a8816e2872f7d1d0b7fc..848870a9c104b719346a0bb318ec200ba0c93300 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -38,6 +38,7 @@ from arraycontext.container.traversal import ( rec_multimap_array_container, rec_multimap_reduce_array_container, ) +from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray from arraycontext.loopy import LoopyBasedFakeNumpyNamespace @@ -215,30 +216,40 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): result = result.get()[()] return result - def array_equal(self, a, b): + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: actx = self._array_context queue = actx.queue # NOTE: pyopencl doesn't like `bool` much, so use `int8` instead - true = actx.from_numpy(np.int8(True)) - false = actx.from_numpy(np.int8(False)) + true_ary = actx.from_numpy(np.int8(True)) + false_ary = actx.from_numpy(np.int8(False)) - def rec_equal(x, y): + def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> cl_array.Array: if type(x) is not type(y): - return false + return false_ary try: - iterable = zip(serialize_container(x), serialize_container(y)) + serialized_x = serialize_container(x) + serialized_y = serialize_container(y) except NotAnArrayContainerError: + assert isinstance(x, cl_array.Array) + assert isinstance(y, cl_array.Array) + if x.shape != y.shape: - return false + return false_ary else: return (x == y).all() else: + if len(serialized_x) != len(serialized_y): + return false_ary + return reduce( partial(cl_array.minimum, queue=queue), - [rec_equal(x_i, y_i)for (_, x_i), (_, y_i) in iterable], - true) + [(true_ary if kx_i == ky_i else false_ary) + and rec_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) result = rec_equal(a, b) if not self._array_context._force_device_scalars: diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index d3d018d6434cefe933c7774225605c2954128e71..c6508e3aba7ac8b9dcf9373c5ca69fb5f6d416d5 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -22,7 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ from functools import partial, reduce -from typing import Any +from typing import Any, cast import numpy as np @@ -34,6 +34,7 @@ from arraycontext.container.traversal import ( rec_map_reduce_array_container, rec_multimap_array_container, ) +from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace from arraycontext.loopy import LoopyBasedFakeNumpyNamespace @@ -171,31 +172,41 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): partial(reduce, pt.logical_or), lambda subary: pt.any(subary), a) - def array_equal(self, a, b): + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: actx = self._array_context # NOTE: not all backends support `bool` properly, so use `int8` instead - true = actx.from_numpy(np.int8(True)) - false = actx.from_numpy(np.int8(False)) + true_ary = actx.from_numpy(np.int8(True)) + false_ary = actx.from_numpy(np.int8(False)) - def rec_equal(x, y): + def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> pt.Array: if type(x) is not type(y): - return false + return false_ary try: - iterable = zip(serialize_container(x), serialize_container(y)) + serialized_x = serialize_container(x) + serialized_y = serialize_container(y) except NotAnArrayContainerError: + assert isinstance(x, pt.Array) + assert isinstance(y, pt.Array) + if x.shape != y.shape: - return false + return false_ary else: - return pt.all(pt.equal(x, y)) + return pt.all(cast(pt.Array, pt.equal(x, y))) else: + if len(serialized_x) != len(serialized_y): + return false_ary + return reduce( pt.logical_and, - [rec_equal(x_i, y_i) for (_, x_i), (_, y_i) in iterable], - true) + [(true_ary if kx_i == ky_i else false_ary) + and rec_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) - return rec_equal(a, b) + return cast(Array, rec_equal(a, b)) # }}}