From ed2374fda1e00f74bc7c91943f9da1a0da2003f5 Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Tue, 1 Jun 2021 16:32:28 -0500
Subject: [PATCH] refactor norm()

---
 arraycontext/fake_numpy.py    | 28 +++++++++++++++++++-
 arraycontext/impl/pyopencl.py | 26 ++----------------
 arraycontext/impl/pytato.py   | 50 ++---------------------------------
 3 files changed, 31 insertions(+), 73 deletions(-)

diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py
index 6b0163d..0657ecb 100644
--- a/arraycontext/fake_numpy.py
+++ b/arraycontext/fake_numpy.py
@@ -24,7 +24,7 @@ THE SOFTWARE.
 
 
 import numpy as np
-from arraycontext.container import is_array_container
+from arraycontext.container import is_array_container, serialize_container
 from arraycontext.container.traversal import (
         rec_map_array_container, multimapped_over_array_containers)
 
@@ -174,6 +174,32 @@ class BaseFakeNumpyLinalgNamespace:
     def __init__(self, array_context):
         self._array_context = array_context
 
+    def norm(self, ary, ord=None):
+        from numbers import Number
+        if isinstance(ary, Number):
+            return abs(ary)
+
+        if is_array_container(ary):
+            import numpy.linalg as la
+            return la.norm(
+                    [self.norm(subary, ord=ord)
+                        for _, subary in serialize_container(ary)],
+                    ord=ord)
+
+        if len(ary.shape) != 1:
+            raise NotImplementedError("only vector norms are implemented")
+
+        if ary.size == 0:
+            return 0
+
+        if ord == np.inf:
+            return self._array_context.np.max(abs(ary))
+        elif isinstance(ord, Number) and ord > 0:
+            return self._array_context.np.sum(abs(ary)**ord)**(1/ord)
+        else:
+            raise NotImplementedError(f"unsupported value of 'ord': {ord}")
+
+
 # }}}
 
 
diff --git a/arraycontext/impl/pyopencl.py b/arraycontext/impl/pyopencl.py
index 1906509..e50a58d 100644
--- a/arraycontext/impl/pyopencl.py
+++ b/arraycontext/impl/pyopencl.py
@@ -39,7 +39,7 @@ from arraycontext.metadata import FirstAxisIsElementsTag
 from arraycontext.fake_numpy import \
         BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace
 from arraycontext.container.traversal import rec_multimap_array_container
-from arraycontext.container import serialize_container, is_array_container
+from arraycontext.container import serialize_container
 from arraycontext.context import ArrayContext
 
 
@@ -165,10 +165,6 @@ def _flatten_array(ary):
 
 class _PyOpenCLFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
     def norm(self, ary, ord=None):
-        from numbers import Number
-        if isinstance(ary, Number):
-            return abs(ary)
-
         if ord is None:
             ord = 2
 
@@ -192,25 +188,7 @@ class _PyOpenCLFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
                             for _, subary in serialize_container(ary)],
                         ord=ord)
 
-        if is_array_container(ary):
-            import numpy.linalg as la
-            return la.norm(
-                    [self.norm(subary, ord=ord)
-                        for _, subary in serialize_container(ary)],
-                    ord=ord)
-
-        if len(ary.shape) != 1:
-            raise NotImplementedError("only vector norms are implemented")
-
-        if ary.size == 0:
-            return 0
-
-        if ord == np.inf:
-            return self._array_context.np.max(abs(ary))
-        elif isinstance(ord, Number) and ord > 0:
-            return self._array_context.np.sum(abs(ary)**ord)**(1/ord)
-        else:
-            raise NotImplementedError(f"unsupported value of 'ord': {ord}")
+        return super().norm(ary, ord)
 
 # }}}
 
diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py
index 5ecf9cf..67fd554 100644
--- a/arraycontext/impl/pytato.py
+++ b/arraycontext/impl/pytato.py
@@ -36,57 +36,11 @@ from pytools.tag import Tag
 from numbers import Number
 import loopy as lp
 
-from arraycontext.container import serialize_container, is_array_container
-
 
 class _PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
     def norm(self, ary, ord=None):
-        from numbers import Number
-        if isinstance(ary, Number):
-            return abs(ary)
-
-        if ord is None:
-            ord = 2
-
-        try:
-            from meshmode.dof_array import DOFArray
-        except ImportError:
-            pass
-        else:
-            if isinstance(ary, DOFArray):
-                from warnings import warn
-                warn("Taking an actx.np.linalg.norm of a DOFArray is deprecated. "
-                        "(DOFArrays use 2D arrays internally, and "
-                        "actx.np.linalg.norm should compute matrix norms of those.) "
-                        "This will stop working in 2022. "
-                        "Use meshmode.dof_array.flat_norm instead.",
-                        DeprecationWarning, stacklevel=2)
-
-                import numpy.linalg as la
-                return la.norm(
-                        [self.norm(_flatten_array(subary), ord=ord)
-                            for _, subary in serialize_container(ary)],
-                        ord=ord)
-
-        if is_array_container(ary):
-            import numpy.linalg as la
-            return la.norm(
-                    [self.norm(subary, ord=ord)
-                        for _, subary in serialize_container(ary)],
-                    ord=ord)
-
-        if len(ary.shape) != 1:
-            raise NotImplementedError("only vector norms are implemented")
-
-        if ary.size == 0:
-            return 0
-
-        if ord == np.inf:
-            return self._array_context.np.max(abs(ary))
-        elif isinstance(ord, Number) and ord > 0:
-            return self._array_context.np.sum(abs(ary)**ord)**(1/ord)
-        else:
-            raise NotImplementedError(f"unsupported value of 'ord': {ord}")
+        # FIXME: handle isinstance(ary, DOFArray) case
+        return super().norm(ary, ord)
 
 
 class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
-- 
GitLab