From 195b9416b044a24e76f31aeaf5ceaeaf62d0052c Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <15399010+kaushikcfd@users.noreply.github.com> Date: Mon, 14 Jun 2021 13:20:01 -0500 Subject: [PATCH] Mimic numpy's norm when ord is None (#30) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * mimic numpy's norm when ord==None - When ord==None, numpy ravels the array and then computes the euclidean norm. * test norm when ord is None * fixup! mimic numpy's norm when ord==None Co-authored-by: Andreas Klöckner <inform@tiker.net> --- arraycontext/impl/pyopencl.py | 10 ++++++++-- test/test_arraycontext.py | 17 +++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pyopencl.py b/arraycontext/impl/pyopencl.py index 3dd5f3f..9361030 100644 --- a/arraycontext/impl/pyopencl.py +++ b/arraycontext/impl/pyopencl.py @@ -189,11 +189,17 @@ def _flatten_array(ary): class _PyOpenCLFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): def norm(self, ary, ord=None): from numbers import Number + import pyopencl.array as cla + if isinstance(ary, Number): return abs(ary) - if ord is None: - ord = 2 + if ord is None and isinstance(ary, cla.Array): + if ary.ndim == 1: + ord = 2 + else: + # mimics numpy's norm computation + return self.norm(_flatten_array(ary), ord=2) try: from meshmode.dof_array import DOFArray diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index d8f9ccb..40a5c60 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -713,6 +713,23 @@ def test_norm_complex(actx_factory, norm_ord): assert abs(norm_a_ref - norm_a)/norm_a < 1e-13 +@pytest.mark.parametrize("ndim", [1, 2, 3, 4, 5]) +def test_norm_ord_none(actx_factory, ndim): + from numpy.random import default_rng + + actx = actx_factory() + + rng = default_rng() + shape = tuple(rng.integers(2, 7, ndim)) + + a = rng.random(shape) + + norm_a_ref = np.linalg.norm(a, ord=None) + norm_a = actx.np.linalg.norm(actx.from_numpy(a), ord=None) + + np.testing.assert_allclose(norm_a, norm_a_ref) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: -- GitLab