diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 3fcc7367bc7e05c071e6f3ef3fab6c204ee15e4e..01efaec844a5ef0baf6620a2b08bbe6c705f172d 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -26,8 +26,11 @@ from functools import partial, reduce from arraycontext.fake_numpy import ( BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace, ) +from arraycontext.container import is_array_container from arraycontext.container.traversal import ( - rec_multimap_array_container, rec_map_array_container, + rec_map_array_container, + rec_multimap_array_container, + multimap_reduce_array_container, rec_map_reduce_array_container, ) import pytato as pt @@ -158,4 +161,31 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): return rec_map_array_container(_rec_ravel, a) + def any(self, a): + return rec_map_reduce_array_container( + partial(reduce, pt.logical_or), + lambda subary: pt.any(subary), a) + + def all(self, a): + return rec_map_reduce_array_container( + partial(reduce, pt.logical_and), + lambda subary: pt.all(subary), a) + + def array_equal(self, a, b): + def as_device_scalar(bool_value): + import numpy as np + return self._array_context.from_numpy( + np.array(int(bool_value), dtype=np.int8)) + + if type(a) != type(b): + return as_device_scalar(False) + elif not is_array_container(a): + if a.shape != b.shape: + return as_device_scalar(False) + else: + return pt.all(pt.equal(a, b)) + else: + return multimap_reduce_array_container( + partial(reduce, pt.logical_and), self.array_equal, a, b) + # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index ed7c1b1ef49964d2651666ef318838a42288d476..880009816cecceff9a5780527941d78c3836b06f 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -578,7 +578,7 @@ def test_array_equal_same_as_numpy(actx_factory): lambda _np, *_args: getattr(_np, sym_name)(*_args), [ary, ary_diff_shape]) # Different types - assert not actx.np.array_equal(ary, ary_diff_type) + assert not actx.to_numpy(actx.np.array_equal(ary, ary_diff_type)) # }}}