From e98e577b37229ebb28d36ae33c87cf130f557bcf Mon Sep 17 00:00:00 2001 From: Matthew Smith <mjsmith6@illinois.edu> Date: Tue, 21 Sep 2021 10:31:51 -0500 Subject: [PATCH] add array_equal to PyOpenCLArrayContext.np --- arraycontext/impl/pyopencl/fake_numpy.py | 36 ++++++++++++++++++++++-- test/test_arraycontext.py | 31 ++++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index a984ef3..0ba81e1 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -31,9 +31,13 @@ import operator from arraycontext.fake_numpy import \ BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace +from arraycontext.container import is_array_container from arraycontext.container.traversal import ( - rec_multimap_array_container, rec_map_array_container, + rec_map_array_container, + rec_multimap_array_container, + multimap_reduce_array_container, rec_map_reduce_array_container, + rec_multimap_reduce_array_container, ) try: @@ -201,7 +205,6 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): return rec_map_array_container(_rec_ravel, a) def vdot(self, x, y, dtype=None): - from arraycontext import rec_multimap_reduce_array_container result = rec_multimap_reduce_array_container( sum, partial(cl_array.vdot, dtype=dtype, queue=self._array_context.queue), @@ -233,6 +236,35 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): result = result.get()[()] return result + def array_equal(self, a, b): + def as_device_scalar(bool_value): + import numpy as np + return self._array_context.from_numpy( + np.array(int(bool_value), dtype=np.int8)) + + # Do recursion separately from device-to-host conversion (below) so that + # we don't pass host booleans to cl_array.minimum + def rec_equal(x, y): + if type(x) != type(y): + return as_device_scalar(False) + elif not is_array_container(x): + if x.shape != y.shape: + return as_device_scalar(False) + else: + return (x == y).all() + else: + queue = self._array_context.queue + reduce_func = partial(reduce, partial(cl_array.minimum, queue=queue)) + map_func = rec_equal + return multimap_reduce_array_container( + reduce_func, map_func, x, y) + + result = rec_equal(a, b) + + if not self._array_context._force_device_scalars: + result = result.get()[()] + return result + # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 9e855b0..268eb3e 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -547,6 +547,37 @@ def test_any_all_same_as_numpy(actx_factory, sym_name): assert_close_to_numpy_in_containers(actx, lambda _np, *_args: getattr(_np, sym_name)(*_args), [1 - ary_all]) + +def test_array_equal_same_as_numpy(actx_factory): + actx = actx_factory() + + sym_name = "array_equal" + if not hasattr(actx.np, sym_name): + pytest.skip(f"'{sym_name}' not implemented on '{type(actx).__name__}'") + + rng = np.random.default_rng() + ary = rng.integers(0, 2, 512) + ary_copy = ary.copy() + ary_diff_values = np.ones(512) + ary_diff_shape = np.ones(511) + ary_diff_type = DOFArray(actx, (np.ones(512),)) + + # Equal + assert_close_to_numpy_in_containers(actx, + lambda _np, *_args: getattr(_np, sym_name)(*_args), [ary, ary_copy]) + + # Different values + assert_close_to_numpy_in_containers(actx, + lambda _np, *_args: getattr(_np, sym_name)(*_args), [ary, ary_diff_values]) + + # Different shapes + assert_close_to_numpy_in_containers(actx, + lambda _np, *_args: getattr(_np, sym_name)(*_args), [ary, ary_diff_shape]) + + # Different types + assert not actx.np.array_equal(ary, ary_diff_type) + + # }}} -- GitLab