From 136e4d493b249004b2302662aaea509016c18a23 Mon Sep 17 00:00:00 2001
From: Thomas Gibson <gibsonthomas1120@hotmail.com>
Date: Sat, 15 May 2021 12:33:20 -0500
Subject: [PATCH] Make op.norm more flexible with inputs

---
 grudge/op.py        | 55 +++++++++++++++++++++++----------------------
 test/test_grudge.py | 10 ++++-----
 2 files changed, 32 insertions(+), 33 deletions(-)

diff --git a/grudge/op.py b/grudge/op.py
index 58aef33b..44b9ede7 100644
--- a/grudge/op.py
+++ b/grudge/op.py
@@ -834,6 +834,21 @@ def face_mass(dcoll, *args):
 
 # {{{ Nodal reductions
 
+def _norm(dcoll, vec, p, dd):
+    if isinstance(vec, Number):
+        return np.fabs(vec)
+    if p == 2:
+        return np.sqrt(
+            nodal_summation(
+                vec * _apply_mass_operator(dcoll, dd, dd, vec)
+            )
+        )
+    elif p == np.inf:
+        return nodal_maximum(dcoll._setup_actx.np.fabs(vec))
+    else:
+        raise NotImplementedError("Unsupported value of p")
+
+
 def norm(dcoll, vec, p, dd=None):
     r"""Return the vector p-norm of a function represented
     by its vector of degrees of freedom *vec*.
@@ -851,35 +866,22 @@ def norm(dcoll, vec, p, dd=None):
     """
     if dd is None:
         dd = dof_desc.DD_VOLUME
-    dd = dof_desc.as_dofdesc(dd)
-
-    if p == 2:
-        norm_squared = nodal_summation(
-            vec * _apply_mass_operator(dcoll, dd, dd, vec)
-        )
 
-        if isinstance(norm_squared, np.ndarray):
-            if len(norm_squared.shape) != 1:
-                raise NotImplementedError("Can only take the norm of vectors")
-
-            norm_squared = sum(norm_squared)
-
-        return np.sqrt(norm_squared)
-
-    elif p == np.inf:
-        actx = dcoll._setup_actx
-        result = nodal_maximum(actx.np.fabs(vec))
-
-        if isinstance(result, np.ndarray):
-            if len(result.shape) != 1:
-                raise NotImplementedError("Can only take the norm of vectors")
-
-            result = np.max(result)
+    dd = dof_desc.as_dofdesc(dd)
 
-        return result
+    if isinstance(vec, np.ndarray):
+        if p == 2:
+            return sum(
+                    norm(dcoll, vec[idx], p, dd=dd)**2
+                    for idx in np.ndindex(vec.shape))**0.5
+        elif p == np.inf:
+            return max(
+                    norm(dcoll, vec[idx], np.inf, dd=dd)
+                    for idx in np.ndindex(vec.shape))
+        else:
+            raise ValueError("unsupported norm order")
 
-    else:
-        raise NotImplementedError("Unsupported value of p")
+    return _norm(dcoll, vec, p, dd)
 
 
 def nodal_sum(dcoll, dd, vec):
@@ -1115,7 +1117,6 @@ def cross_rank_trace_pairs(dcoll, ary, tag=None):
     :class:`~meshmode.dof_array.DOFArray`, or an object
     array of ``DOFArray``\ s of arbitrary shape.
     """
-
     if isinstance(ary, np.ndarray):
         oshape = ary.shape
         comm_vec = ary.flatten()
diff --git a/test/test_grudge.py b/test/test_grudge.py
index 8abf6156..eee6280d 100644
--- a/test/test_grudge.py
+++ b/test/test_grudge.py
@@ -1089,7 +1089,7 @@ def test_function_symbol_array(actx_factory, array_type):
 
 @pytest.mark.parametrize("p", [2, np.inf])
 def test_norm_obj_array(actx_factory, p):
-    """Test :func:`grudge.symbolic.operators.norm` for object arrays."""
+    """Test :func:`grudge.op.norm` for object arrays."""
 
     actx = actx_factory()
 
@@ -1097,14 +1097,13 @@ def test_norm_obj_array(actx_factory, p):
     mesh = mgen.generate_regular_rect_mesh(
             a=(-0.5,)*dim, b=(0.5,)*dim,
             nelements_per_axis=(8,)*dim, order=1)
-    discr = DiscretizationCollection(actx, mesh, order=4)
+    dcoll = DiscretizationCollection(actx, mesh, order=4)
 
     w = make_obj_array([1.0, 2.0, 3.0])[:dim]
 
     # {{ scalar
 
-    sym_w = sym.var("w")
-    norm = bind(discr, sym.norm(p, sym_w))(actx, w=w[0])
+    norm = op.norm(dcoll, w[0], p)
 
     norm_exact = w[0]
     logger.info("norm: %.5e %.5e", norm, norm_exact)
@@ -1114,8 +1113,7 @@ def test_norm_obj_array(actx_factory, p):
 
     # {{{ vector
 
-    sym_w = sym.make_sym_array("w", dim)
-    norm = bind(discr, sym.norm(p, sym_w))(actx, w=w)
+    norm = op.norm(dcoll, w, p)
 
     norm_exact = np.sqrt(np.sum(w**2)) if p == 2 else np.max(w)
     logger.info("norm: %.5e %.5e", norm, norm_exact)
-- 
GitLab