From 195b9416b044a24e76f31aeaf5ceaeaf62d0052c Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <15399010+kaushikcfd@users.noreply.github.com>
Date: Mon, 14 Jun 2021 13:20:01 -0500
Subject: [PATCH] Mimic numpy's norm when ord is None (#30)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* mimic numpy's norm when ord==None

- When ord==None, numpy ravels the array and then computes the euclidean
  norm.

* test norm when ord is None

* fixup! mimic numpy's norm when ord==None

Co-authored-by: Andreas Klöckner <inform@tiker.net>
---
 arraycontext/impl/pyopencl.py | 10 ++++++++--
 test/test_arraycontext.py     | 17 +++++++++++++++++
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/arraycontext/impl/pyopencl.py b/arraycontext/impl/pyopencl.py
index 3dd5f3f..9361030 100644
--- a/arraycontext/impl/pyopencl.py
+++ b/arraycontext/impl/pyopencl.py
@@ -189,11 +189,17 @@ def _flatten_array(ary):
 class _PyOpenCLFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
     def norm(self, ary, ord=None):
         from numbers import Number
+        import pyopencl.array as cla
+
         if isinstance(ary, Number):
             return abs(ary)
 
-        if ord is None:
-            ord = 2
+        if ord is None and isinstance(ary, cla.Array):
+            if ary.ndim == 1:
+                ord = 2
+            else:
+                # mimics numpy's norm computation
+                return self.norm(_flatten_array(ary), ord=2)
 
         try:
             from meshmode.dof_array import DOFArray
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index d8f9ccb..40a5c60 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -713,6 +713,23 @@ def test_norm_complex(actx_factory, norm_ord):
     assert abs(norm_a_ref - norm_a)/norm_a < 1e-13
 
 
+@pytest.mark.parametrize("ndim", [1, 2, 3, 4, 5])
+def test_norm_ord_none(actx_factory, ndim):
+    from numpy.random import default_rng
+
+    actx = actx_factory()
+
+    rng = default_rng()
+    shape = tuple(rng.integers(2, 7, ndim))
+
+    a = rng.random(shape)
+
+    norm_a_ref = np.linalg.norm(a, ord=None)
+    norm_a = actx.np.linalg.norm(actx.from_numpy(a), ord=None)
+
+    np.testing.assert_allclose(norm_a, norm_a_ref)
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
-- 
GitLab