From bf68ea8a6dc06b8600a407e36131c1f6edd09e28 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Sun, 16 Aug 2020 20:41:11 -0500
Subject: [PATCH] fix max norm handling

---
 grudge/symbolic/dofdesc_inference.py |  2 ++
 grudge/symbolic/operators.py         |  5 ++--
 test/test_grudge.py                  | 40 ++++++++++++++++++++++++++++
 3 files changed, 44 insertions(+), 3 deletions(-)

diff --git a/grudge/symbolic/dofdesc_inference.py b/grudge/symbolic/dofdesc_inference.py
index 38c77512..f5f11f0a 100644
--- a/grudge/symbolic/dofdesc_inference.py
+++ b/grudge/symbolic/dofdesc_inference.py
@@ -144,6 +144,8 @@ class DOFDescInferenceMapper(RecursiveMapper, CSECachingMapperMixin):
         return self.map_multi_child(expr, expr.children)
 
     map_product = map_sum
+    map_max = map_sum
+    map_min = map_sum
 
     def map_quotient(self, expr):
         return self.map_multi_child(expr, (expr.numerator, expr.denominator))
diff --git a/grudge/symbolic/operators.py b/grudge/symbolic/operators.py
index b5d24060..b47611d7 100644
--- a/grudge/symbolic/operators.py
+++ b/grudge/symbolic/operators.py
@@ -730,11 +730,10 @@ def norm(p, arg, dd=None):
 
     elif p == np.Inf:
         result = NodalMax(dd_in=dd)(prim.fabs(arg))
-        from pymbolic.primitives import Max
 
         if isinstance(result, np.ndarray):
-            from functools import reduce
-            result = reduce(Max, result)
+            from pymbolic.primitives import Max
+            result = Max(result)
 
         return result
 
diff --git a/test/test_grudge.py b/test/test_grudge.py
index b6d82a3f..60f58ddf 100644
--- a/test/test_grudge.py
+++ b/test/test_grudge.py
@@ -645,6 +645,46 @@ def test_function_symbol_array(ctx_factory, array_type):
     assert isinstance(norm, float)
 
 
+@pytest.mark.parametrize("p", [2, np.inf])
+def test_norm_obj_array(ctx_factory, p):
+    """Test :func:`grudge.symbolic.operators.norm` for object arrays."""
+
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+    actx = PyOpenCLArrayContext(queue)
+
+    from meshmode.mesh.generation import generate_regular_rect_mesh
+    dim = 2
+    mesh = generate_regular_rect_mesh(
+            a=(-0.5,)*dim, b=(0.5,)*dim,
+            n=(8,)*dim, order=1)
+    discr = DGDiscretizationWithBoundaries(actx, mesh, order=4)
+
+    w = make_obj_array([1.0, 2.0, 3.0])[:dim]
+
+    # {{ scalar
+
+    sym_w = sym.var("w")
+    norm = bind(discr, sym.norm(p, sym_w))(actx, w=w[0])
+
+    norm_exact = w[0]
+    logger.info("norm: %.5e %.5e", norm, norm_exact)
+    # assert abs(norm - norm_exact) < 1.0e-14
+
+    # }}}
+
+    # {{{ vector
+
+    sym_w = sym.make_sym_array("w", dim)
+    norm = bind(discr, sym.norm(p, sym_w))(actx, w=w)
+
+    norm_exact = np.sqrt(np.sum(w**2)) if p == 2 else np.max(w)
+    logger.info("norm: %.5e %.5e", norm, norm_exact)
+    # assert abs(norm - norm_exact) < 1.0e-14
+
+    # }}}
+
+
 # You can test individual routines by typing
 # $ python test_grudge.py 'test_routine()'
 
-- 
GitLab