From 95fd7be8645e6d35d71aa903f70d441facc9f0d8 Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Thu, 17 Jun 2021 12:30:35 -0500 Subject: [PATCH] pyopenclac.to_numpy special case isscalar --- arraycontext/impl/pyopencl.py | 3 +++ test/test_arraycontext.py | 22 +++++++++------------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/arraycontext/impl/pyopencl.py b/arraycontext/impl/pyopencl.py index 694ba14..4de6d0d 100644 --- a/arraycontext/impl/pyopencl.py +++ b/arraycontext/impl/pyopencl.py @@ -323,6 +323,9 @@ class PyOpenCLArrayContext(ArrayContext): return cla.to_device(self.queue, array, allocator=self.allocator) def to_numpy(self, array): + from numpy import isscalar + if isscalar(array): + return array return array.get(queue=self.queue) def call_loopy(self, t_unit, **kwargs): diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 41d18e4..fce673f 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -401,18 +401,16 @@ def test_dof_array_arithmetic_same_as_numpy(actx_factory): # {{{ reductions same as numpy - -def test_dof_array_reductions_same_as_numpy(actx_factory): +@pytest.mark.parametrize("op", ["sum", "min", "max"]) +def test_dof_array_reductions_same_as_numpy(actx_factory, op): actx = actx_factory() - for name in ["sum", "min", "max"]: - ary = np.random.randn(3000) - np_red = getattr(np, name)(ary) - actx_red = getattr(actx.np, name)(actx.from_numpy(ary)) - if not np.isscalar(actx_red): - actx_red = actx.to_numpy(actx_red) + ary = np.random.randn(3000) + np_red = getattr(np, op)(ary) + actx_red = getattr(actx.np, op)(actx.from_numpy(ary)) + actx_red = actx.to_numpy(actx_red) - assert np.allclose(np_red, actx_red) + assert np.allclose(np_red, actx_red) # }}} @@ -725,8 +723,7 @@ def test_norm_complex(actx_factory, norm_ord): norm_a_ref = np.linalg.norm(a, norm_ord) norm_a = actx.np.linalg.norm(actx.from_numpy(a), norm_ord) - if not np.isscalar(norm_a): - norm_a = actx.to_numpy(norm_a) + norm_a = actx.to_numpy(norm_a) assert abs(norm_a_ref - norm_a)/norm_a < 1e-13 @@ -745,8 +742,7 @@ def test_norm_ord_none(actx_factory, ndim): norm_a_ref = np.linalg.norm(a, ord=None) norm_a = actx.np.linalg.norm(actx.from_numpy(a), ord=None) - if not np.isscalar(norm_a): - norm_a = actx.to_numpy(norm_a) + norm_a = actx.to_numpy(norm_a) np.testing.assert_allclose(norm_a, norm_a_ref) -- GitLab