diff --git a/meshmode/array_context.py b/meshmode/array_context.py index 81461b71aa9008e52fcb379532cfc51ea5de4b2c..4405b5ddd23949478a1dd9fe9f5a941f97e003fe 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -224,6 +224,11 @@ class _PyOpenCLFakeNumpyNamespace(_BaseFakeNumpyNamespace): return super().__getattr__(name) + @obj_array_vectorized_n_args + def where(self, criterion, then, else_): + import pyopencl.array as cl_array + return cl_array.if_positive(criterion.astype(np.bool), then, else_) + class PyOpenCLArrayContext(ArrayContext): """ diff --git a/test/test_meshmode.py b/test/test_meshmode.py index e6312f55cc7a8e88c62988ebb98f28edab5fb98e..697b7662b52921c220b535c6bd969e055db5e837 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -1478,6 +1478,17 @@ def test_array_context_np_workalike(ctx_factory): assert np.allclose(actx_result, ref_result) + n_args = 2 + args = [np.random.randn(discr.ndofs) for i in range(n_args)] + ref_result = np.where(((args[0] - args[1]) > 0), args[0], args[1]) + + actx_args = [unflatten(actx, discr, actx.from_numpy(arg)) for arg in args] + actx_result = actx.to_numpy( + flatten(actx.np.where(((args[0] - args[1]) > 0), args[0], args[1])) + ) + + assert np.allclose(actx_result, ref_result) + if __name__ == "__main__": import sys