From 5706e44ae9a9fd19055825e41c4a8d7a14de1643 Mon Sep 17 00:00:00 2001 From: Matt Smith Date: Wed, 29 Jun 2022 11:58:37 -0500 Subject: [PATCH] Support empty containers in `array_equal` (#182) * support empty containers in array_equal * support empty containers in array_equal (jax) --- arraycontext/impl/jax/fake_numpy.py | 5 +++-- arraycontext/impl/pyopencl/fake_numpy.py | 5 +++-- arraycontext/impl/pytato/fake_numpy.py | 5 +++-- test/test_arraycontext.py | 7 ++++++- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 8a72d9a..394fbd5 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -129,6 +129,7 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace): actx = self._array_context # NOTE: not all backends support `bool` properly, so use `int8` instead + true = actx.from_numpy(numpy.int8(True)) false = actx.from_numpy(numpy.int8(False)) def rec_equal(x, y): @@ -145,8 +146,8 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace): else: return reduce( jnp.logical_and, - [rec_equal(ix, iy) for (_, ix), (_, iy) in iterable] - ) + [rec_equal(ix, iy) for (_, ix), (_, iy) in iterable], + true) return rec_equal(a, b) diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index bae5b34..3c9be87 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -169,6 +169,7 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): queue = actx.queue # NOTE: pyopencl doesn't like `bool` much, so use `int8` instead + true = actx.from_numpy(np.int8(True)) false = actx.from_numpy(np.int8(False)) def rec_equal(x, y): @@ -185,8 +186,8 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): else: return reduce( partial(cl_array.minimum, queue=queue), - [rec_equal(ix, iy)for (_, ix), (_, iy) in iterable] - ) + [rec_equal(ix, iy)for (_, ix), (_, iy) in iterable], + true) result = rec_equal(a, b) if not self._array_context._force_device_scalars: diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index babdf4a..0f219d1 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -148,6 +148,7 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): actx = self._array_context # NOTE: not all backends support `bool` properly, so use `int8` instead + true = actx.from_numpy(np.int8(True)) false = actx.from_numpy(np.int8(False)) def rec_equal(x, y): @@ -164,8 +165,8 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): else: return reduce( pt.logical_and, - [rec_equal(ix, iy) for (_, ix), (_, iy) in iterable] - ) + [rec_equal(ix, iy) for (_, ix), (_, iy) in iterable], + true) return rec_equal(a, b) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 154af2f..928f446 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -576,7 +576,7 @@ def test_any_all_same_as_numpy(actx_factory, sym_name): lambda _np, *_args: getattr(_np, sym_name)(*_args), [1 - ary_all]) -def test_array_equal_same_as_numpy(actx_factory): +def test_array_equal(actx_factory): actx = actx_factory() sym_name = "array_equal" @@ -605,6 +605,11 @@ def test_array_equal_same_as_numpy(actx_factory): # Different types assert not actx.to_numpy(actx.np.array_equal(ary, ary_diff_type)) + # Empty + ary_empty = np.empty((5, 0), dtype=object) + ary_empty_copy = ary_empty.copy() + assert actx.to_numpy(actx.np.array_equal(ary_empty, ary_empty_copy)) + # }}} -- GitLab