From 4b5f2bfbfcb175b7434b847eb77d03f32801d239 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sat, 2 Oct 2021 20:46:39 -0500 Subject: [PATCH] extend if_positive to broadcast over host and device scalars --- pyopencl/array.py | 53 ++++++++++++++++++++++++++++++++++------- pyopencl/elementwise.py | 28 ++++++++++++++++++---- test/test_array.py | 35 +++++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 14 deletions(-) diff --git a/pyopencl/array.py b/pyopencl/array.py index 905d8583..d48eece6 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -2790,7 +2790,12 @@ def reshape(a, shape): @elwise_kernel_runner def _if_positive(result, criterion, then_, else_): return elementwise.get_if_positive_kernel( - result.context, criterion.dtype, then_.dtype) + result.context, criterion.dtype, then_.dtype, + is_then_array=isinstance(then_, Array), + is_else_array=isinstance(else_, Array), + is_then_scalar=then_.shape == (), + is_else_scalar=else_.shape == (), + ) def if_positive(criterion, then_, else_, out=None, queue=None): @@ -2798,9 +2803,9 @@ def if_positive(criterion, then_, else_, out=None, queue=None): contains *then_[i]* if *criterion[i]>0*, else *else_[i]*. """ - if (isinstance(criterion, SCALAR_CLASSES) - and isinstance(then_, SCALAR_CLASSES) - and isinstance(else_, SCALAR_CLASSES)): + is_then_scalar = isinstance(then_, SCALAR_CLASSES) + is_else_scalar = isinstance(else_, SCALAR_CLASSES) + if isinstance(criterion, SCALAR_CLASSES) and is_then_scalar and is_else_scalar: result = np.where(criterion, then_, else_) if out is not None: @@ -2809,16 +2814,46 @@ def if_positive(criterion, then_, else_, out=None, queue=None): return result - if not (criterion.shape == then_.shape == else_.shape): - raise ValueError("shapes do not match") + if is_then_scalar: + then_ = np.array(then_) - if not (then_.dtype == else_.dtype): - raise ValueError("dtypes do not match") + if is_else_scalar: + else_ = np.array(else_) + + if then_.dtype != else_.dtype: + raise ValueError( + f"dtypes do not match: then_ is '{then_.dtype}' and " + f"else_ is '{else_.dtype}'") + + if then_.shape == () and else_.shape == (): + pass + elif then_.shape != () and else_.shape != (): + if not (criterion.shape == then_.shape == else_.shape): + raise ValueError( + f"shapes do not match: 'criterion' has shape {criterion.shape}" + f", 'then_' has shape {then_.shape} and 'else_' has shape " + f"{else_.shape}") + elif then_.shape == (): + if criterion.shape != else_.shape: + raise ValueError( + f"shapes do not match: 'criterion' has shape {criterion.shape}" + f" and 'else_' has shape {else_.shape}") + elif else_.shape == (): + if criterion.shape != then_.shape: + raise ValueError( + f"shapes do not match: 'criterion' has shape {criterion.shape}" + f" and 'then_' has shape {then_.shape}") + else: + raise AssertionError() if out is None: - out = empty_like(then_) + out = empty( + criterion.queue, criterion.shape, then_.dtype, + allocator=criterion.allocator) + event1 = _if_positive(out, criterion, then_, else_, queue=queue) out.add_event(event1) + return out diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py index c6e4d4bf..9f51c299 100644 --- a/pyopencl/elementwise.py +++ b/pyopencl/elementwise.py @@ -1035,14 +1035,32 @@ def get_diff_kernel(context, dtype): @context_dependent_memoize -def get_if_positive_kernel(context, crit_dtype, dtype): +def get_if_positive_kernel( + context, crit_dtype, then_else_dtype, + is_then_array, is_else_array, + is_then_scalar, is_else_scalar): + if is_then_array: + then_ = "then_[0]" if is_then_scalar else "then_[i]" + then_arg = VectorArg(then_else_dtype, "then_", with_offset=True) + else: + assert is_then_scalar + then_ = "then_" + then_arg = ScalarArg(then_else_dtype, "then_") + + if is_else_array: + else_ = "else_[0]" if is_else_scalar else "else_[i]" + else_arg = VectorArg(then_else_dtype, "else_", with_offset=True) + else: + assert is_else_scalar + else_ = "else_" + else_arg = ScalarArg(then_else_dtype, "else_") + return get_elwise_kernel(context, [ - VectorArg(dtype, "result", with_offset=True), + VectorArg(then_else_dtype, "result", with_offset=True), VectorArg(crit_dtype, "crit", with_offset=True), - VectorArg(dtype, "then_", with_offset=True), - VectorArg(dtype, "else_", with_offset=True), + then_arg, else_arg, ], - "result[i] = crit[i] > 0 ? then_[i] : else_[i]", + f"result[i] = crit[i] > 0 ? {then_} : {else_}", name="if_positive") # }}} diff --git a/test/test_array.py b/test/test_array.py index 2d74e9ce..f0510bf8 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -1645,6 +1645,41 @@ def test_arithmetic_with_device_scalars(ctx_factory, which): np.testing.assert_allclose(res_cl.get(), res_np) +@pytest.mark.parametrize("then_type", ["array", "host_scalar", "device_scalar"]) +@pytest.mark.parametrize("else_type", ["array", "host_scalar", "device_scalar"]) +def test_if_positive_with_scalars(ctx_factory, then_type, else_type): + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + rng = np.random.default_rng() + shape = (512,) + + criterion_np = rng.random(shape) + criterion_cl = cl_array.to_device(cq, criterion_np) + + def _get_array_or_scalar(rtype, value): + if rtype == "array": + ary_np = value + np.zeros(shape, dtype=criterion_cl.dtype) + ary_cl = value + cl_array.zeros_like(criterion_cl) + elif rtype == "host_scalar": + ary_np = ary_cl = value + elif rtype == "device_scalar": + ary_np = value + ary_cl = cl_array.to_device(cq, np.array(value)) + else: + raise ValueError(rtype) + + return ary_np, ary_cl + + then_np, then_cl = _get_array_or_scalar(then_type, 0.0) + else_np, else_cl = _get_array_or_scalar(else_type, 1.0) + + result_cl = cl_array.if_positive(criterion_cl < 0.5, then_cl, else_cl) + result_np = np.where(criterion_np < 0.5, then_np, else_np) + + np.testing.assert_allclose(result_cl.get(), result_np) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab