diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 8a72d9aa41b7cf29b8f7e4248e879818c6bcdbf8..394fbd529c62a5a22897b42f99f69927f9aaf4c0 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -129,6 +129,7 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace): actx = self._array_context # NOTE: not all backends support `bool` properly, so use `int8` instead + true = actx.from_numpy(numpy.int8(True)) false = actx.from_numpy(numpy.int8(False)) def rec_equal(x, y): @@ -145,8 +146,8 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace): else: return reduce( jnp.logical_and, - [rec_equal(ix, iy) for (_, ix), (_, iy) in iterable] - ) + [rec_equal(ix, iy) for (_, ix), (_, iy) in iterable], + true) return rec_equal(a, b) diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index bae5b34fae1bac2939e2396150686cfaf66e80b0..3c9be87dc3f76fadaba9e5270aa276fba735ddad 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -169,6 +169,7 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): queue = actx.queue # NOTE: pyopencl doesn't like `bool` much, so use `int8` instead + true = actx.from_numpy(np.int8(True)) false = actx.from_numpy(np.int8(False)) def rec_equal(x, y): @@ -185,8 +186,8 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): else: return reduce( partial(cl_array.minimum, queue=queue), - [rec_equal(ix, iy)for (_, ix), (_, iy) in iterable] - ) + [rec_equal(ix, iy)for (_, ix), (_, iy) in iterable], + true) result = rec_equal(a, b) if not self._array_context._force_device_scalars: diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index babdf4a03231bec238c155cbaed2f34f8d30ba37..0f219d139b087a64a79ca38ce55b9e08e7483fda 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -148,6 +148,7 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): actx = self._array_context # NOTE: not all backends support `bool` properly, so use `int8` instead + true = actx.from_numpy(np.int8(True)) false = actx.from_numpy(np.int8(False)) def rec_equal(x, y): @@ -164,8 +165,8 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): else: return reduce( pt.logical_and, - [rec_equal(ix, iy) for (_, ix), (_, iy) in iterable] - ) + [rec_equal(ix, iy) for (_, ix), (_, iy) in iterable], + true) return rec_equal(a, b) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 154af2f97d8abd151235bacfd74d741e3d190085..928f4462813e9097cd5ac70a76d808e980b70668 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -576,7 +576,7 @@ def test_any_all_same_as_numpy(actx_factory, sym_name): lambda _np, *_args: getattr(_np, sym_name)(*_args), [1 - ary_all]) -def test_array_equal_same_as_numpy(actx_factory): +def test_array_equal(actx_factory): actx = actx_factory() sym_name = "array_equal" @@ -605,6 +605,11 @@ def test_array_equal_same_as_numpy(actx_factory): # Different types assert not actx.to_numpy(actx.np.array_equal(ary, ary_diff_type)) + # Empty + ary_empty = np.empty((5, 0), dtype=object) + ary_empty_copy = ary_empty.copy() + assert actx.to_numpy(actx.np.array_equal(ary_empty, ary_empty_copy)) + # }}}