From 2b0ae43ba4d1bcbb85c63693191a79deb4fff7cf Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Mon, 12 Jul 2021 16:27:45 -0500 Subject: [PATCH] add vdot to pyopencl array context --- arraycontext/impl/pyopencl/fake_numpy.py | 21 ++++++++++++++ test/test_arraycontext.py | 37 ++++++++++++++++++++---- 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 20e0d48..bd9eb08 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -162,6 +162,27 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): return rec_map_array_container(_rec_ravel, a) + def vdot(self, x, y, dtype=None): + import pyopencl.array as cl_array + from arraycontext import is_array_container, serialize_container + + def _rec_vdot(xi, yi): + if is_array_container(xi): + assert type(xi) == type(yi) + return sum(_rec_vdot(subxi, subyi) + for (_, subxi), (_, subyi) in zip( + serialize_container(xi), serialize_container(yi) + )) + else: + result = cl_array.vdot(xi, yi, + dtype=dtype, queue=self._array_context.queue) + if not self._array_context._force_device_scalars: + result = result.get()[()] + + return result + + return _rec_vdot(x, y) + # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index a653655..4642806 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -208,9 +208,12 @@ def assert_close_to_numpy_in_containers(actx, op, args): DOFArray(actx, (actx.from_numpy(arg),)) if isinstance(arg, np.ndarray) else arg for arg in args] - actx_result = actx.to_numpy(op(actx.np, *dofarray_args)[0]) - assert np.allclose(actx_result, ref_result) + actx_result = op(actx.np, *dofarray_args) + if isinstance(actx_result, DOFArray): + actx_result = actx_result[0] + + assert np.allclose(actx.to_numpy(actx_result), ref_result) # }}} @@ -219,9 +222,12 @@ def assert_close_to_numpy_in_containers(actx, op, args): obj_array_args = [ make_obj_array([arg]) if isinstance(arg, DOFArray) else arg for arg in dofarray_args] - obj_array_result = actx.to_numpy(op(actx.np, *obj_array_args)[0][0]) - assert np.allclose(obj_array_result, ref_result) + obj_array_result = op(actx.np, *obj_array_args) + if isinstance(obj_array_result, np.ndarray): + obj_array_result = obj_array_result[0][0] + + assert np.allclose(actx.to_numpy(obj_array_result), ref_result) # }}} @@ -238,9 +244,12 @@ def assert_close_to_numpy_in_containers(actx, op, args): ("maximum", 2), ("where", 3), ("conj", 1), + ("vdot", 2), ]) def test_array_context_np_workalike(actx_factory, sym_name, n_args): actx = actx_factory() + if not hasattr(actx.np, sym_name): + pytest.skip(f"'{sym_name}' not implemented on '{type(actx).__name__}'") ndofs = 5000 args = [np.random.randn(ndofs) for i in range(n_args)] @@ -780,6 +789,8 @@ def test_numpy_conversion(actx_factory): # }}} +# {{{ test actx.np.linalg.norm + @pytest.mark.parametrize("norm_ord", [2, np.inf]) def test_norm_complex(actx_factory, norm_ord): actx = actx_factory() @@ -809,6 +820,8 @@ def test_norm_ord_none(actx_factory, ndim): np.testing.assert_allclose(actx.to_numpy(norm_a), norm_a_ref) +# }}} + # {{{ test_actx_compile helpers @@ -828,8 +841,6 @@ def scale_and_orthogonalize(alpha, vel): vel) return Velocity2D(-scaled_vel.v, scaled_vel.u, actx) -# }}} - def test_actx_compile(actx_factory): from arraycontext import (to_numpy, from_numpy) @@ -848,6 +859,10 @@ def test_actx_compile(actx_factory): np.testing.assert_allclose(result.u, -3.14*v_y) np.testing.assert_allclose(result.v, 3.14*v_x) +# }}} + + +# {{{ test_container_equality def test_container_equality(actx_factory): actx = actx_factory() @@ -865,6 +880,10 @@ def test_container_equality(actx_factory): assert isinstance(bcast_dc_of_dofs == bcast_dc_of_dofs_2, MyContainerDOFBcast) +# }}} + + +# {{{ test_abs_complex def test_abs_complex(actx_factory): actx = actx_factory() @@ -876,6 +895,10 @@ def test_abs_complex(actx_factory): assert abs_a.dtype == abs_a_ref.dtype np.testing.assert_allclose(actx.to_numpy(abs_a), abs_a_ref) +# }}} + + +# {{{ test_leaf_array_type_broadcasting @with_container_arithmetic( bcast_obj_array=True, @@ -925,6 +948,8 @@ def test_leaf_array_type_broadcasting(actx_factory): np.testing.assert_allclose(actx.to_numpy(bar.u[0]), actx.to_numpy(quuz.u[0])) +# }}} + if __name__ == "__main__": import sys -- GitLab