From 4b2287878a88e454e030a2423d4087d9fd5e3434 Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Tue, 1 Jun 2021 11:27:38 -0500
Subject: [PATCH] copy norm() from PyOpenCLArrayContext

---
 arraycontext/impl/pytato.py | 52 ++++++++++++++++++++++++++++++++++---
 test/test_arraycontext.py   |  2 +-
 2 files changed, 50 insertions(+), 4 deletions(-)

diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py
index c021829..25220e6 100644
--- a/arraycontext/impl/pytato.py
+++ b/arraycontext/impl/pytato.py
@@ -36,11 +36,57 @@ from pytools.tag import Tag
 from numbers import Number
 import loopy as lp
 
+from arraycontext.container import serialize_container, is_array_container
 
-class _PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
-    def norm(self, array, ord=None):
-        raise NotImplementedError
 
+class _PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
+    def norm(self, ary, ord=None):
+        from numbers import Number
+        if isinstance(ary, Number):
+            return abs(ary)
+
+        if ord is None:
+            ord = 2
+
+        try:
+            from meshmode.dof_array import DOFArray
+        except ImportError:
+            pass
+        else:
+            if isinstance(ary, DOFArray):
+                from warnings import warn
+                warn("Taking an actx.np.linalg.norm of a DOFArray is deprecated. "
+                        "(DOFArrays use 2D arrays internally, and "
+                        "actx.np.linalg.norm should compute matrix norms of those.) "
+                        "This will stop working in 2022. "
+                        "Use meshmode.dof_array.flat_norm instead.",
+                        DeprecationWarning, stacklevel=2)
+
+                import numpy.linalg as la
+                return la.norm(
+                        [self.norm(_flatten_array(subary), ord=ord)
+                            for _, subary in serialize_container(ary)],
+                        ord=ord)
+
+        if is_array_container(ary):
+            import numpy.linalg as la
+            return la.norm(
+                    [self.norm(subary, ord=ord)
+                        for _, subary in serialize_container(ary)],
+                    ord=ord)
+
+        if len(ary.shape) != 1:
+            raise NotImplementedError("only vector norms are implemented")
+
+        if ary.size == 0:
+            return 0
+
+        if ord == np.inf:
+            return self._array_context.np.max(abs(ary))
+        elif isinstance(ord, Number) and ord > 0:
+            return self._array_context.np.sum(abs(ary)**ord)**(1/ord)
+        else:
+            raise NotImplementedError(f"unsupported value of 'ord': {ord}")
 
 class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
     def _get_fake_numpy_linalg_namespace(self):
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index abd9843..119a9fc 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -212,7 +212,7 @@ def test_actx_stack(actx_factory):
     ndofs = 5000
     args = [np.random.randn(ndofs) for i in range(10)]
 
-    assert_close_to_numpy_in_containers(
+    assert_close_to_numpy(
             actx, lambda _np, *_args: _np.stack(_args), args)
 
 
-- 
GitLab