From f4b73fc7be7719dfd423d6c640b8d0857e481ed1 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <15399010+kaushikcfd@users.noreply.github.com>
Date: Tue, 6 Jul 2021 00:40:30 -0500
Subject: [PATCH] Broadcast binary ops with device scalars (#502)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* implement broadcasting array binary ops with device scalars

* test broadcasting array binary ops with device scalars

* re-add some asserts (better to be safe)

* be explicit in error msg

op -> operator

Co-authored-by: Andreas Klöckner <inform@tiker.net>

* array binops: include asserts on out.shape as well

* set default args more elegantly

* Array shape checks: save shapes in temporaries

Co-authored-by: Andreas Klöckner <inform@tiker.net>
---
 pyopencl/array.py       | 119 +++++++++++++++++++++++++++++++---------
 pyopencl/elementwise.py |  39 ++++++++-----
 test/test_array.py      |  24 ++++++++
 3 files changed, 141 insertions(+), 41 deletions(-)

diff --git a/pyopencl/array.py b/pyopencl/array.py
index 9627986d..fb80f2d7 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -104,6 +104,23 @@ def _get_truedivide_dtype(obj1, obj2, queue):
     return result
 
 
+def _get_broadcasted_binary_op_result(obj1, obj2, cq,
+                                      dtype_getter=_get_common_dtype):
+
+    if obj1.shape == obj2.shape:
+        return obj1._new_like_me(dtype_getter(obj1, obj2, cq),
+                                 cq)
+    elif obj1.shape == ():
+        return obj2._new_like_me(dtype_getter(obj1, obj2, cq),
+                                 cq)
+    elif obj2.shape == ():
+        return obj1._new_like_me(dtype_getter(obj1, obj2, cq),
+                                 cq)
+    else:
+        raise NotImplementedError("Broadcasting binary operator with shapes:"
+                                  f" {obj1.shape}, {obj2.shape}.")
+
+
 class InconsistentOpenCLQueueWarning(UserWarning):
     pass
 
@@ -874,11 +891,16 @@ class Array:
     def _axpbyz(out, afac, a, bfac, b, queue=None):
         """Compute ``out = selffac * self + otherfac*other``,
         where *other* is an array."""
-        assert out.shape == a.shape
-        assert out.shape == b.shape
-
+        a_shape = a.shape
+        b_shape = b.shape
+        out_shape = out.shape
+        assert (a_shape == b_shape == out_shape
+                or (a_shape == () and b_shape == out_shape)
+                or (b_shape == () and a_shape == out_shape))
         return elementwise.get_axpbyz_kernel(
-                out.context, a.dtype, b.dtype, out.dtype)
+                out.context, a.dtype, b.dtype, out.dtype,
+                x_is_scalar=(a_shape == ()),
+                y_is_scalar=(b_shape == ()))
 
     @staticmethod
     @elwise_kernel_runner
@@ -893,10 +915,17 @@ class Array:
     @staticmethod
     @elwise_kernel_runner
     def _elwise_multiply(out, a, b, queue=None):
-        assert out.shape == a.shape
-        assert out.shape == b.shape
+        a_shape = a.shape
+        b_shape = b.shape
+        out_shape = out.shape
+        assert (a_shape == b_shape == out_shape
+                or (a_shape == () and b_shape == out_shape)
+                or (b_shape == () and a_shape == out_shape))
         return elementwise.get_multiply_kernel(
-                a.context, a.dtype, b.dtype, out.dtype)
+                a.context, a.dtype, b.dtype, out.dtype,
+                x_is_scalar=(a_shape == ()),
+                y_is_scalar=(b_shape == ())
+        )
 
     @staticmethod
     @elwise_kernel_runner
@@ -910,11 +939,14 @@ class Array:
     @elwise_kernel_runner
     def _div(out, self, other, queue=None):
         """Divides an array by another array."""
-
-        assert self.shape == other.shape
+        assert (self.shape == other.shape == out.shape
+                or (self.shape == () and other.shape == out.shape)
+                or (other.shape == () and self.shape == out.shape))
 
         return elementwise.get_divide_kernel(self.context,
-                self.dtype, other.dtype, out.dtype)
+                self.dtype, other.dtype, out.dtype,
+                x_is_scalar=(self.shape == ()),
+                y_is_scalar=(other.shape == ()))
 
     @staticmethod
     @elwise_kernel_runner
@@ -1027,10 +1059,16 @@ class Array:
     @staticmethod
     @elwise_kernel_runner
     def _array_binop(out, a, b, queue=None, op=None):
-        if a.shape != b.shape:
-            raise ValueError("shapes of binop arguments do not match")
+        a_shape = a.shape
+        b_shape = b.shape
+        out_shape = out.shape
+        assert (a_shape == b_shape == out_shape
+                or (a_shape == () and b_shape == out_shape)
+                or (b_shape == () and a_shape == out_shape))
         return elementwise.get_array_binop_kernel(
-                out.context, op, out.dtype, a.dtype, b.dtype)
+                out.context, op, out.dtype, a.dtype, b.dtype,
+                a_is_scalar=(a_shape == ()),
+                b_is_scalar=(b_shape == ()))
 
     @staticmethod
     @elwise_kernel_runner
@@ -1047,8 +1085,7 @@ class Array:
     def mul_add(self, selffac, other, otherfac, queue=None):
         """Return `selffac * self + otherfac*other`.
         """
-        result = self._new_like_me(
-                _get_common_dtype(self, other, queue or self.queue))
+        result = _get_broadcasted_binary_op_result(self, other, queue or self.queue)
         result.add_event(
                 self._axpbyz(result, selffac, self, otherfac, other))
         return result
@@ -1058,8 +1095,7 @@ class Array:
 
         if isinstance(other, Array):
             # add another vector
-            result = self._new_like_me(
-                    _get_common_dtype(self, other, self.queue))
+            result = _get_broadcasted_binary_op_result(self, other, self.queue)
 
             result.add_event(
                     self._axpbyz(result,
@@ -1087,8 +1123,7 @@ class Array:
         """Substract an array from an array or a scalar from an array."""
 
         if isinstance(other, Array):
-            result = self._new_like_me(
-                    _get_common_dtype(self, other, self.queue))
+            result = _get_broadcasted_binary_op_result(self, other, self.queue)
             result.add_event(
                     self._axpbyz(result,
                         self.dtype.type(1), self,
@@ -1123,6 +1158,10 @@ class Array:
 
     def __iadd__(self, other):
         if isinstance(other, Array):
+            if (other.shape != self.shape
+                    and other.shape != ()):
+                raise NotImplementedError("Broadcasting binary op with shapes:"
+                                          f" {self.shape}, {other.shape}.")
             self.add_event(
                     self._axpbyz(self,
                         self.dtype.type(1), self,
@@ -1135,6 +1174,10 @@ class Array:
 
     def __isub__(self, other):
         if isinstance(other, Array):
+            if (other.shape != self.shape
+                    and other.shape != ()):
+                raise NotImplementedError("Broadcasting binary op with shapes:"
+                                          f" {self.shape}, {other.shape}.")
             self.add_event(
                     self._axpbyz(self, self.dtype.type(1), self,
                         other.dtype.type(-1), other))
@@ -1155,8 +1198,7 @@ class Array:
 
     def __mul__(self, other):
         if isinstance(other, Array):
-            result = self._new_like_me(
-                    _get_common_dtype(self, other, self.queue))
+            result = _get_broadcasted_binary_op_result(self, other, self.queue)
             result.add_event(
                     self._elwise_multiply(result, self, other))
             return result
@@ -1180,6 +1222,10 @@ class Array:
 
     def __imul__(self, other):
         if isinstance(other, Array):
+            if (other.shape != self.shape
+                    and other.shape != ()):
+                raise NotImplementedError("Broadcasting binary op with shapes:"
+                                          f" {self.shape}, {other.shape}.")
             self.add_event(
                     self._elwise_multiply(self, self, other))
             return self
@@ -1194,15 +1240,17 @@ class Array:
     def __div__(self, other):
         """Divides an array by an array or a scalar, i.e. ``self / other``.
         """
-        common_dtype = _get_truedivide_dtype(self, other, self.queue)
         if isinstance(other, Array):
-            result = self._new_like_me(common_dtype)
+            result = _get_broadcasted_binary_op_result(
+                            self, other, self.queue,
+                            dtype_getter=_get_truedivide_dtype)
             result.add_event(self._div(result, self, other))
             return result
         elif np.isscalar(other):
             if other == 1:
                 return self.copy()
             else:
+                common_dtype = _get_truedivide_dtype(self, other, self.queue)
                 # create a new array for the result
                 result = self._new_like_me(common_dtype)
                 result.add_event(
@@ -1243,6 +1291,10 @@ class Array:
                             .format(self.dtype, common_dtype))
 
         if isinstance(other, Array):
+            if (other.shape != self.shape
+                    and other.shape != ()):
+                raise NotImplementedError("Broadcasting binary op with shapes:"
+                                          f" {self.shape}, {other.shape}.")
             self.add_event(
                 self._div(self, self, other))
             return self
@@ -1264,7 +1316,8 @@ class Array:
             raise TypeError("Integral types only")
 
         if isinstance(other, Array):
-            result = self._new_like_me(common_dtype)
+            result = _get_broadcasted_binary_op_result(self, other,
+                                                       self.queue)
             result.add_event(self._array_binop(result, self, other, op="&"))
         else:
             # create a new array for the result
@@ -1283,7 +1336,8 @@ class Array:
             raise TypeError("Integral types only")
 
         if isinstance(other, Array):
-            result = self._new_like_me(common_dtype)
+            result = _get_broadcasted_binary_op_result(self, other,
+                                                       self.queue)
             result.add_event(self._array_binop(result, self, other, op="|"))
         else:
             # create a new array for the result
@@ -1302,7 +1356,8 @@ class Array:
             raise TypeError("Integral types only")
 
         if isinstance(other, Array):
-            result = self._new_like_me(common_dtype)
+            result = _get_broadcasted_binary_op_result(self, other,
+                                                       self.queue)
             result.add_event(self._array_binop(result, self, other, op="^"))
         else:
             # create a new array for the result
@@ -1321,6 +1376,10 @@ class Array:
             raise TypeError("Integral types only")
 
         if isinstance(other, Array):
+            if (other.shape != self.shape
+                    and other.shape != ()):
+                raise NotImplementedError("Broadcasting binary op with shapes:"
+                                          f" {self.shape}, {other.shape}.")
             self.add_event(self._array_binop(self, self, other, op="&"))
         else:
             self.add_event(
@@ -1335,6 +1394,10 @@ class Array:
             raise TypeError("Integral types only")
 
         if isinstance(other, Array):
+            if (other.shape != self.shape
+                    and other.shape != ()):
+                raise NotImplementedError("Broadcasting binary op with shapes:"
+                                          f" {self.shape}, {other.shape}.")
             self.add_event(self._array_binop(self, self, other, op="|"))
         else:
             self.add_event(
@@ -1349,6 +1412,10 @@ class Array:
             raise TypeError("Integral types only")
 
         if isinstance(other, Array):
+            if (other.shape != self.shape
+                    and other.shape != ()):
+                raise NotImplementedError("Broadcasting binary op with shapes:"
+                                          f" {self.shape}, {other.shape}.")
             self.add_event(self._array_binop(self, self, other, op="^"))
         else:
             self.add_event(
diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py
index 863b2315..c6e4d4bf 100644
--- a/pyopencl/elementwise.py
+++ b/pyopencl/elementwise.py
@@ -493,32 +493,36 @@ def real_dtype(dtype):
 
 
 @context_dependent_memoize
-def get_axpbyz_kernel(context, dtype_x, dtype_y, dtype_z):
+def get_axpbyz_kernel(context, dtype_x, dtype_y, dtype_z,
+                      x_is_scalar=False, y_is_scalar=False):
     result_t = dtype_to_ctype(dtype_z)
 
     x_is_complex = dtype_x.kind == "c"
     y_is_complex = dtype_y.kind == "c"
 
+    x = "x[0]" if x_is_scalar else "x[i]"
+    y = "y[0]" if y_is_scalar else "y[i]"
+
     if dtype_z.kind == "c":
         # a and b will always be complex here.
         z_ct = complex_dtype_to_name(dtype_z)
 
         if x_is_complex:
-            ax = f"{z_ct}_mul(a, {z_ct}_cast(x[i]))"
+            ax = f"{z_ct}_mul(a, {z_ct}_cast({x}))"
         else:
-            ax = f"{z_ct}_mulr(a, x[i])"
+            ax = f"{z_ct}_mulr(a, {x})"
 
         if y_is_complex:
-            by = f"{z_ct}_mul(b, {z_ct}_cast(y[i]))"
+            by = f"{z_ct}_mul(b, {z_ct}_cast({y}))"
         else:
-            by = f"{z_ct}_mulr(b, y[i])"
+            by = f"{z_ct}_mulr(b, {y})"
 
         result = f"{z_ct}_add({ax}, {by})"
     else:
         # real-only
 
-        ax = f"a*(({result_t}) x[i])"
-        by = f"b*(({result_t}) y[i])"
+        ax = f"a*(({result_t}) {x})"
+        by = f"b*(({result_t}) {y})"
 
         result = f"{ax} + {by}"
 
@@ -594,12 +598,13 @@ def get_axpbz_kernel(context, dtype_a, dtype_x, dtype_b, dtype_z):
 
 
 @context_dependent_memoize
-def get_multiply_kernel(context, dtype_x, dtype_y, dtype_z):
+def get_multiply_kernel(context, dtype_x, dtype_y, dtype_z,
+                        x_is_scalar=False, y_is_scalar=False):
     x_is_complex = dtype_x.kind == "c"
     y_is_complex = dtype_y.kind == "c"
 
-    x = "x[i]"
-    y = "y[i]"
+    x = "x[0]" if x_is_scalar else "x[i]"
+    y = "y[0]" if y_is_scalar else "y[i]"
 
     if x_is_complex and dtype_x != dtype_z:
         x = "{}_cast({})".format(complex_dtype_to_name(dtype_z), x)
@@ -626,13 +631,14 @@ def get_multiply_kernel(context, dtype_x, dtype_y, dtype_z):
 
 
 @context_dependent_memoize
-def get_divide_kernel(context, dtype_x, dtype_y, dtype_z):
+def get_divide_kernel(context, dtype_x, dtype_y, dtype_z,
+                      x_is_scalar=False, y_is_scalar=False):
     x_is_complex = dtype_x.kind == "c"
     y_is_complex = dtype_y.kind == "c"
     z_is_complex = dtype_z.kind == "c"
 
-    x = "x[i]"
-    y = "y[i]"
+    x = "x[0]" if x_is_scalar else "x[i]"
+    y = "y[0]" if y_is_scalar else "y[i]"
 
     if z_is_complex and dtype_x != dtype_y:
         if x_is_complex and dtype_x != dtype_z:
@@ -809,13 +815,16 @@ def get_array_scalar_binop_kernel(context, operator, dtype_res, dtype_a, dtype_b
 
 
 @context_dependent_memoize
-def get_array_binop_kernel(context, operator, dtype_res, dtype_a, dtype_b):
+def get_array_binop_kernel(context, operator, dtype_res, dtype_a, dtype_b,
+                           a_is_scalar=False, b_is_scalar=False):
+    a = "a[0]" if a_is_scalar else "a[i]"
+    b = "b[0]" if b_is_scalar else "b[i]"
     return get_elwise_kernel(context, [
         VectorArg(dtype_res, "out", with_offset=True),
         VectorArg(dtype_a, "a", with_offset=True),
         VectorArg(dtype_b, "b", with_offset=True),
         ],
-        "out[i] = a[i] %s b[i]" % operator,
+        f"out[i] = {a} {operator} {b}",
         name="binop_kernel")
 
 
diff --git a/test/test_array.py b/test/test_array.py
index 6bc16138..2d74e9ce 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -1621,6 +1621,30 @@ def test_arithmetic_on_non_scalars(ctx_factory):
         ArrayContainer(np.ones(100)) + cl.array.zeros(cq, (10,), dtype=np.float64)
 
 
+@pytest.mark.parametrize("which", ("add", "sub", "mul", "truediv"))
+def test_arithmetic_with_device_scalars(ctx_factory, which):
+    import operator
+    from numpy.random import default_rng
+
+    ctx = ctx_factory()
+    cq = cl.CommandQueue(ctx)
+
+    rng = default_rng()
+    ndim = rng.integers(1, 5)
+
+    shape = tuple(rng.integers(2, 7) for i in range(ndim))
+
+    x_in = rng.random(shape)
+    x_cl = cl_array.to_device(cq, x_in)
+    idx = tuple(rng.integers(0, dim) for dim in shape)
+
+    op = getattr(operator, which)
+    res_cl = op(x_cl, x_cl[idx])
+    res_np = op(x_in, x_in[idx])
+
+    np.testing.assert_allclose(res_cl.get(), res_np)
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab