From f4b73fc7be7719dfd423d6c640b8d0857e481ed1 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <15399010+kaushikcfd@users.noreply.github.com>
Date: Tue, 6 Jul 2021 00:40:30 -0500
Subject: [PATCH] Broadcast binary ops with device scalars (#502)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* implement broadcasting array binary ops with device scalars

* test broadcasting array binary ops with device scalars

* re-add some asserts (better to be safe)

* be explicit in error msg op -> operator

Co-authored-by: Andreas Klöckner <inform@tiker.net>

* array binops: include asserts on out.shape as well

* set default args more elegantly

* Array shape checks: save shapes in temporaries

Co-authored-by: Andreas Klöckner <inform@tiker.net>
---
 pyopencl/array.py       | 119 +++++++++++++++++++++++++++++++---------
 pyopencl/elementwise.py |  39 ++++++++-----
 test/test_array.py      |  24 ++++++++
 3 files changed, 141 insertions(+), 41 deletions(-)

diff --git a/pyopencl/array.py b/pyopencl/array.py
index 9627986d..fb80f2d7 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -104,6 +104,23 @@ def _get_truedivide_dtype(obj1, obj2, queue):
     return result
 
 
+def _get_broadcasted_binary_op_result(obj1, obj2, cq,
+                                      dtype_getter=_get_common_dtype):
+
+    if obj1.shape == obj2.shape:
+        return obj1._new_like_me(dtype_getter(obj1, obj2, cq),
+                                 cq)
+    elif obj1.shape == ():
+        return obj2._new_like_me(dtype_getter(obj1, obj2, cq),
+                                 cq)
+    elif obj2.shape == ():
+        return obj1._new_like_me(dtype_getter(obj1, obj2, cq),
+                                 cq)
+    else:
+        raise NotImplementedError("Broadcasting binary operator with shapes:"
+                                  f" {obj1.shape}, {obj2.shape}.")
+
+
 class InconsistentOpenCLQueueWarning(UserWarning):
     pass
 
@@ -874,11 +891,16 @@ class Array:
     @staticmethod
     @elwise_kernel_runner
     def _axpbyz(out, afac, a, bfac, b, queue=None):
         """Compute ``out = selffac * self + otherfac*other``,
         where *other* is an array."""
-        assert out.shape == a.shape
-        assert out.shape == b.shape
-
+        a_shape = a.shape
+        b_shape = b.shape
+        out_shape = out.shape
+        assert (a_shape == b_shape == out_shape
+                or (a_shape == () and b_shape == out_shape)
+                or (b_shape == () and a_shape == out_shape))
         return elementwise.get_axpbyz_kernel(
-                out.context, a.dtype, b.dtype, out.dtype)
+                out.context, a.dtype, b.dtype, out.dtype,
+                x_is_scalar=(a_shape == ()),
+                y_is_scalar=(b_shape == ()))
 
     @staticmethod
     @elwise_kernel_runner
@@ -893,10 +915,17 @@ class Array:
     @staticmethod
     @elwise_kernel_runner
     def _elwise_multiply(out, a, b, queue=None):
-        assert out.shape == a.shape
-        assert out.shape == b.shape
+        a_shape = a.shape
+        b_shape = b.shape
+        out_shape = out.shape
+        assert (a_shape == b_shape == out_shape
+                or (a_shape == () and b_shape == out_shape)
+                or (b_shape == () and a_shape == out_shape))
         return elementwise.get_multiply_kernel(
-                a.context, a.dtype, b.dtype, out.dtype)
+                a.context, a.dtype, b.dtype, out.dtype,
+                x_is_scalar=(a_shape == ()),
+                y_is_scalar=(b_shape == ())
+                )
 
     @staticmethod
     @elwise_kernel_runner
@@ -910,11 +939,14 @@ class Array:
     @staticmethod
     @elwise_kernel_runner
     def _div(out, self, other, queue=None):
         """Divides an array by another array."""
-
-        assert self.shape == other.shape
+        assert (self.shape == other.shape == out.shape
+                or (self.shape == () and other.shape == out.shape)
+                or (other.shape == () and self.shape == out.shape))
 
         return elementwise.get_divide_kernel(self.context,
-                self.dtype, other.dtype, out.dtype)
+                self.dtype, other.dtype, out.dtype,
+                x_is_scalar=(self.shape == ()),
+                y_is_scalar=(other.shape == ()))
 
     @staticmethod
     @elwise_kernel_runner
@@ -1027,10 +1059,16 @@ class Array:
     @staticmethod
     @elwise_kernel_runner
     def _array_binop(out, a, b, queue=None, op=None):
-        if a.shape != b.shape:
-            raise ValueError("shapes of binop arguments do not match")
+        a_shape = a.shape
+        b_shape = b.shape
+        out_shape = out.shape
+        assert (a_shape == b_shape == out_shape
+                or (a_shape == () and b_shape == out_shape)
+                or (b_shape == () and a_shape == out_shape))
         return elementwise.get_array_binop_kernel(
-                out.context, op, out.dtype, a.dtype, b.dtype)
+                out.context, op, out.dtype, a.dtype, b.dtype,
+                a_is_scalar=(a_shape == ()),
+                b_is_scalar=(b_shape == ()))
 
     @staticmethod
     @elwise_kernel_runner
@@ -1047,8 +1085,7 @@ class Array:
     def mul_add(self, selffac, other, otherfac, queue=None):
         """Return `selffac * self + otherfac*other`.
         """
-        result = self._new_like_me(
-                _get_common_dtype(self, other, queue or self.queue))
+        result = _get_broadcasted_binary_op_result(self, other, queue or self.queue)
         result.add_event(
                 self._axpbyz(result, selffac, self, otherfac, other))
         return result
@@ -1058,8 +1095,7 @@ class Array:
 
         if isinstance(other, Array):
             # add another vector
-            result = self._new_like_me(
-                    _get_common_dtype(self, other, self.queue))
+            result = _get_broadcasted_binary_op_result(self, other, self.queue)
 
             result.add_event(
                     self._axpbyz(result,
@@ -1087,8 +1123,7 @@ class Array:
         """Substract an array from an array or a scalar from an array."""
 
         if isinstance(other, Array):
-            result = self._new_like_me(
-                    _get_common_dtype(self, other, self.queue))
+            result = _get_broadcasted_binary_op_result(self, other, self.queue)
             result.add_event(
                     self._axpbyz(result, self.dtype.type(1), self,
@@ -1123,6 +1158,10 @@ class Array:
 
     def __iadd__(self, other):
         if isinstance(other, Array):
+            if (other.shape != self.shape
+                    and other.shape != ()):
+                raise NotImplementedError("Broadcasting binary op with shapes:"
+                                          f" {self.shape}, {other.shape}.")
             self.add_event(
                     self._axpbyz(self,
                         self.dtype.type(1), self,
@@ -1135,6 +1174,10 @@ class Array:
 
     def __isub__(self, other):
         if isinstance(other, Array):
+            if (other.shape != self.shape
+                    and other.shape != ()):
+                raise NotImplementedError("Broadcasting binary op with shapes:"
+                                          f" {self.shape}, {other.shape}.")
             self.add_event(
                     self._axpbyz(self, self.dtype.type(1), self,
                         other.dtype.type(-1), other))
@@ -1155,8 +1198,7 @@ class Array:
 
     def __mul__(self, other):
         if isinstance(other, Array):
-            result = self._new_like_me(
-                    _get_common_dtype(self, other, self.queue))
+            result = _get_broadcasted_binary_op_result(self, other, self.queue)
             result.add_event(
                     self._elwise_multiply(result, self, other))
             return result
@@ -1180,6 +1222,10 @@ class Array:
 
     def __imul__(self, other):
         if isinstance(other, Array):
+            if (other.shape != self.shape
+                    and other.shape != ()):
+                raise NotImplementedError("Broadcasting binary op with shapes:"
+                                          f" {self.shape}, {other.shape}.")
             self.add_event(
                     self._elwise_multiply(self, self, other))
             return self
@@ -1194,15 +1240,17 @@ class Array:
     def __div__(self, other):
         """Divides an array by an array or a scalar, i.e. ``self / other``.
""" - common_dtype = _get_truedivide_dtype(self, other, self.queue) if isinstance(other, Array): - result = self._new_like_me(common_dtype) + result = _get_broadcasted_binary_op_result( + self, other, self.queue, + dtype_getter=_get_truedivide_dtype) result.add_event(self._div(result, self, other)) return result elif np.isscalar(other): if other == 1: return self.copy() else: + common_dtype = _get_truedivide_dtype(self, other, self.queue) # create a new array for the result result = self._new_like_me(common_dtype) result.add_event( @@ -1243,6 +1291,10 @@ class Array: .format(self.dtype, common_dtype)) if isinstance(other, Array): + if (other.shape != self.shape + and other.shape != ()): + raise NotImplementedError("Broadcasting binary op with shapes:" + f" {self.shape}, {other.shape}.") self.add_event( self._div(self, self, other)) return self @@ -1264,7 +1316,8 @@ class Array: raise TypeError("Integral types only") if isinstance(other, Array): - result = self._new_like_me(common_dtype) + result = _get_broadcasted_binary_op_result(self, other, + self.queue) result.add_event(self._array_binop(result, self, other, op="&")) else: # create a new array for the result @@ -1283,7 +1336,8 @@ class Array: raise TypeError("Integral types only") if isinstance(other, Array): - result = self._new_like_me(common_dtype) + result = _get_broadcasted_binary_op_result(self, other, + self.queue) result.add_event(self._array_binop(result, self, other, op="|")) else: # create a new array for the result @@ -1302,7 +1356,8 @@ class Array: raise TypeError("Integral types only") if isinstance(other, Array): - result = self._new_like_me(common_dtype) + result = _get_broadcasted_binary_op_result(self, other, + self.queue) result.add_event(self._array_binop(result, self, other, op="^")) else: # create a new array for the result @@ -1321,6 +1376,10 @@ class Array: raise TypeError("Integral types only") if isinstance(other, Array): + if (other.shape != self.shape + and other.shape != ()): + raise NotImplementedError("Broadcasting binary op with shapes:" + f" {self.shape}, {other.shape}.") self.add_event(self._array_binop(self, self, other, op="&")) else: self.add_event( @@ -1335,6 +1394,10 @@ class Array: raise TypeError("Integral types only") if isinstance(other, Array): + if (other.shape != self.shape + and other.shape != ()): + raise NotImplementedError("Broadcasting binary op with shapes:" + f" {self.shape}, {other.shape}.") self.add_event(self._array_binop(self, self, other, op="|")) else: self.add_event( @@ -1349,6 +1412,10 @@ class Array: raise TypeError("Integral types only") if isinstance(other, Array): + if (other.shape != self.shape + and other.shape != ()): + raise NotImplementedError("Broadcasting binary op with shapes:" + f" {self.shape}, {other.shape}.") self.add_event(self._array_binop(self, self, other, op="^")) else: self.add_event( diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py index 863b2315..c6e4d4bf 100644 --- a/pyopencl/elementwise.py +++ b/pyopencl/elementwise.py @@ -493,32 +493,36 @@ def real_dtype(dtype): @context_dependent_memoize -def get_axpbyz_kernel(context, dtype_x, dtype_y, dtype_z): +def get_axpbyz_kernel(context, dtype_x, dtype_y, dtype_z, + x_is_scalar=False, y_is_scalar=False): result_t = dtype_to_ctype(dtype_z) x_is_complex = dtype_x.kind == "c" y_is_complex = dtype_y.kind == "c" + x = "x[0]" if x_is_scalar else "x[i]" + y = "y[0]" if y_is_scalar else "y[i]" + if dtype_z.kind == "c": # a and b will always be complex here. 
         z_ct = complex_dtype_to_name(dtype_z)
 
         if x_is_complex:
-            ax = f"{z_ct}_mul(a, {z_ct}_cast(x[i]))"
+            ax = f"{z_ct}_mul(a, {z_ct}_cast({x}))"
         else:
-            ax = f"{z_ct}_mulr(a, x[i])"
+            ax = f"{z_ct}_mulr(a, {x})"
 
         if y_is_complex:
-            by = f"{z_ct}_mul(b, {z_ct}_cast(y[i]))"
+            by = f"{z_ct}_mul(b, {z_ct}_cast({y}))"
         else:
-            by = f"{z_ct}_mulr(b, y[i])"
+            by = f"{z_ct}_mulr(b, {y})"
 
         result = f"{z_ct}_add({ax}, {by})"
     else:
         # real-only
-        ax = f"a*(({result_t}) x[i])"
-        by = f"b*(({result_t}) y[i])"
+        ax = f"a*(({result_t}) {x})"
+        by = f"b*(({result_t}) {y})"
         result = f"{ax} + {by}"
 
@@ -594,12 +598,13 @@ def get_axpbz_kernel(context, dtype_a, dtype_x, dtype_b, dtype_z):
 
 
 @context_dependent_memoize
-def get_multiply_kernel(context, dtype_x, dtype_y, dtype_z):
+def get_multiply_kernel(context, dtype_x, dtype_y, dtype_z,
+                        x_is_scalar=False, y_is_scalar=False):
     x_is_complex = dtype_x.kind == "c"
     y_is_complex = dtype_y.kind == "c"
 
-    x = "x[i]"
-    y = "y[i]"
+    x = "x[0]" if x_is_scalar else "x[i]"
+    y = "y[0]" if y_is_scalar else "y[i]"
 
     if x_is_complex and dtype_x != dtype_z:
         x = "{}_cast({})".format(complex_dtype_to_name(dtype_z), x)
@@ -626,13 +631,14 @@ def get_multiply_kernel(context, dtype_x, dtype_y, dtype_z):
 
 
 @context_dependent_memoize
-def get_divide_kernel(context, dtype_x, dtype_y, dtype_z):
+def get_divide_kernel(context, dtype_x, dtype_y, dtype_z,
+                      x_is_scalar=False, y_is_scalar=False):
     x_is_complex = dtype_x.kind == "c"
     y_is_complex = dtype_y.kind == "c"
     z_is_complex = dtype_z.kind == "c"
 
-    x = "x[i]"
-    y = "y[i]"
+    x = "x[0]" if x_is_scalar else "x[i]"
+    y = "y[0]" if y_is_scalar else "y[i]"
 
     if z_is_complex and dtype_x != dtype_y:
         if x_is_complex and dtype_x != dtype_z:
@@ -809,13 +815,16 @@ def get_array_scalar_binop_kernel(context, operator, dtype_res, dtype_a, dtype_b
 
 
 @context_dependent_memoize
-def get_array_binop_kernel(context, operator, dtype_res, dtype_a, dtype_b):
+def get_array_binop_kernel(context, operator, dtype_res, dtype_a, dtype_b,
+        a_is_scalar=False, b_is_scalar=False):
+    a = "a[0]" if a_is_scalar else "a[i]"
+    b = "b[0]" if b_is_scalar else "b[i]"
     return get_elwise_kernel(context, [
         VectorArg(dtype_res, "out", with_offset=True),
         VectorArg(dtype_a, "a", with_offset=True),
         VectorArg(dtype_b, "b", with_offset=True),
         ],
-            "out[i] = a[i] %s b[i]" % operator,
+            f"out[i] = {a} {operator} {b}",
             name="binop_kernel")
 
diff --git a/test/test_array.py b/test/test_array.py
index 6bc16138..2d74e9ce 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -1621,6 +1621,30 @@ def test_arithmetic_on_non_scalars(ctx_factory):
         ArrayContainer(np.ones(100)) + cl.array.zeros(cq, (10,), dtype=np.float64)
 
 
+@pytest.mark.parametrize("which", ("add", "sub", "mul", "truediv"))
+def test_arithmetic_with_device_scalars(ctx_factory, which):
+    import operator
+    from numpy.random import default_rng
+
+    ctx = ctx_factory()
+    cq = cl.CommandQueue(ctx)
+
+    rng = default_rng()
+    ndim = rng.integers(1, 5)
+
+    shape = tuple(rng.integers(2, 7) for i in range(ndim))
+
+    x_in = rng.random(shape)
+    x_cl = cl_array.to_device(cq, x_in)
+    idx = tuple(rng.integers(0, dim) for dim in shape)
+
+    op = getattr(operator, which)
+    res_cl = op(x_cl, x_cl[idx])
+    res_np = op(x_in, x_in[idx])
+
+    np.testing.assert_allclose(res_cl.get(), res_np)
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab
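
For reference, a minimal usage sketch of the broadcasting behavior added above. This is an illustration rather than part of the patch, and it assumes a usable OpenCL platform reachable via create_some_context(); the variable names (x_np, x_cl, scal) are ours.

    import numpy as np
    import pyopencl as cl
    import pyopencl.array as cl_array

    ctx = cl.create_some_context()
    queue = cl.CommandQueue(ctx)

    x_np = np.arange(10, dtype=np.float64)
    x_cl = cl_array.to_device(queue, x_np)

    # Indexing with an integer yields a shape-() Array ("device scalar");
    # it stays on the device, no host transfer is needed for the arithmetic.
    scal = x_cl[3]

    # With this patch, array/scalar binary operations broadcast the device scalar.
    print((x_cl + scal).get())   # matches x_np + x_np[3]
    print((x_cl * scal).get())   # matches x_np * x_np[3]
    print((x_cl / scal).get())   # matches x_np / x_np[3]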