From 2b0ae43ba4d1bcbb85c63693191a79deb4fff7cf Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Mon, 12 Jul 2021 16:27:45 -0500
Subject: [PATCH] add vdot to pyopencl array context

---
 arraycontext/impl/pyopencl/fake_numpy.py | 21 ++++++++++++++
 test/test_arraycontext.py                | 37 ++++++++++++++++++++----
 2 files changed, 52 insertions(+), 6 deletions(-)

diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py
index 20e0d48..bd9eb08 100644
--- a/arraycontext/impl/pyopencl/fake_numpy.py
+++ b/arraycontext/impl/pyopencl/fake_numpy.py
@@ -162,6 +162,27 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace):
 
         return rec_map_array_container(_rec_ravel, a)
 
+    def vdot(self, x, y, dtype=None):
+        import pyopencl.array as cl_array
+        from arraycontext import is_array_container, serialize_container
+
+        def _rec_vdot(xi, yi):
+            if is_array_container(xi):
+                assert type(xi) == type(yi)
+                return sum(_rec_vdot(subxi, subyi)
+                    for (_, subxi), (_, subyi) in zip(
+                        serialize_container(xi), serialize_container(yi)
+                    ))
+            else:
+                result = cl_array.vdot(xi, yi,
+                    dtype=dtype, queue=self._array_context.queue)
+                if not self._array_context._force_device_scalars:
+                    result = result.get()[()]
+
+                return result
+
+        return _rec_vdot(x, y)
+
 # }}}
 
 
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index a653655..4642806 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -208,9 +208,12 @@ def assert_close_to_numpy_in_containers(actx, op, args):
             DOFArray(actx, (actx.from_numpy(arg),))
             if isinstance(arg, np.ndarray) else arg
             for arg in args]
-    actx_result = actx.to_numpy(op(actx.np, *dofarray_args)[0])
 
-    assert np.allclose(actx_result, ref_result)
+    actx_result = op(actx.np, *dofarray_args)
+    if isinstance(actx_result, DOFArray):
+        actx_result = actx_result[0]
+
+    assert np.allclose(actx.to_numpy(actx_result), ref_result)
 
     # }}}
 
@@ -219,9 +222,12 @@ def assert_close_to_numpy_in_containers(actx, op, args):
     obj_array_args = [
             make_obj_array([arg]) if isinstance(arg, DOFArray) else arg
             for arg in dofarray_args]
-    obj_array_result = actx.to_numpy(op(actx.np, *obj_array_args)[0][0])
 
-    assert np.allclose(obj_array_result, ref_result)
+    obj_array_result = op(actx.np, *obj_array_args)
+    if isinstance(obj_array_result, np.ndarray):
+        obj_array_result = obj_array_result[0][0]
+
+    assert np.allclose(actx.to_numpy(obj_array_result), ref_result)
 
     # }}}
 
@@ -238,9 +244,12 @@ def assert_close_to_numpy_in_containers(actx, op, args):
             ("maximum", 2),
             ("where", 3),
             ("conj", 1),
+            ("vdot", 2),
             ])
 def test_array_context_np_workalike(actx_factory, sym_name, n_args):
     actx = actx_factory()
+    if not hasattr(actx.np, sym_name):
+        pytest.skip(f"'{sym_name}' not implemented on '{type(actx).__name__}'")
 
     ndofs = 5000
     args = [np.random.randn(ndofs) for i in range(n_args)]
@@ -780,6 +789,8 @@ def test_numpy_conversion(actx_factory):
 # }}}
 
 
+# {{{ test actx.np.linalg.norm
+
 @pytest.mark.parametrize("norm_ord", [2, np.inf])
 def test_norm_complex(actx_factory, norm_ord):
     actx = actx_factory()
@@ -809,6 +820,8 @@ def test_norm_ord_none(actx_factory, ndim):
 
     np.testing.assert_allclose(actx.to_numpy(norm_a), norm_a_ref)
 
+# }}}
+
 
 # {{{ test_actx_compile helpers
 
@@ -828,8 +841,6 @@ def scale_and_orthogonalize(alpha, vel):
                                          vel)
     return Velocity2D(-scaled_vel.v, scaled_vel.u, actx)
 
-# }}}
-
 
 def test_actx_compile(actx_factory):
     from arraycontext import (to_numpy, from_numpy)
@@ -848,6 +859,10 @@ def test_actx_compile(actx_factory):
     np.testing.assert_allclose(result.u, -3.14*v_y)
     np.testing.assert_allclose(result.v, 3.14*v_x)
 
+# }}}
+
+
+# {{{ test_container_equality
 
 def test_container_equality(actx_factory):
     actx = actx_factory()
@@ -865,6 +880,10 @@ def test_container_equality(actx_factory):
 
     assert isinstance(bcast_dc_of_dofs == bcast_dc_of_dofs_2, MyContainerDOFBcast)
 
+# }}}
+
+
+# {{{ test_abs_complex
 
 def test_abs_complex(actx_factory):
     actx = actx_factory()
@@ -876,6 +895,10 @@ def test_abs_complex(actx_factory):
     assert abs_a.dtype == abs_a_ref.dtype
     np.testing.assert_allclose(actx.to_numpy(abs_a), abs_a_ref)
 
+# }}}
+
+
+# {{{ test_leaf_array_type_broadcasting
 
 @with_container_arithmetic(
     bcast_obj_array=True,
@@ -925,6 +948,8 @@ def test_leaf_array_type_broadcasting(actx_factory):
         np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
                                    actx.to_numpy(quuz.u[0]))
 
+# }}}
+
 
 if __name__ == "__main__":
     import sys
-- 
GitLab