From 6b9f8297000b23db2b1832cb0f74a7a131e2b30e Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Wed, 22 Sep 2021 11:25:35 -0500
Subject: [PATCH] avoid call_loopy by taking a combination of sum/max/min and
 broadcast_to

---
 grudge/reductions.py | 65 ++++++++++++++++++++++++++------------------
 1 file changed, 38 insertions(+), 27 deletions(-)

diff --git a/grudge/reductions.py b/grudge/reductions.py
index ec301576..08a87335 100644
--- a/grudge/reductions.py
+++ b/grudge/reductions.py
@@ -295,35 +295,46 @@ def _apply_elementwise_reduction(
 
     actx = vec.array_context
 
-    @memoize_in(actx, (_apply_elementwise_reduction,
-                       "elementwise_%s_prg" % op_name))
-    def elementwise_prg():
-        # FIXME: This computes the reduction value redundantly for each
-        # output DOF.
-        t_unit = make_loopy_program(
-            [
-                "{[iel]: 0 <= iel < nelements}",
-                "{[idof, jdof]: 0 <= idof, jdof < ndofs}"
-            ],
-            """
-                result[iel, idof] = %s(jdof, operand[iel, jdof])
-            """ % op_name,
-            name="grudge_elementwise_%s_knl" % op_name
+    if actx.supports_nonscalar_broadcasting:
+        return DOFArray(
+            actx,
+            data=tuple(
+                actx.np.broadcast_to((getattr(actx.np, op_name)(vec_i, axis=1)
+                                      .reshape(-1, 1)),
+                                     vec_i.shape)
+                for vec_i in vec
+            )
         )
-        import loopy as lp
-        from meshmode.transform_metadata import (
-                ConcurrentElementInameTag, ConcurrentDOFInameTag)
-        return lp.tag_inames(t_unit, {
-            "iel": ConcurrentElementInameTag(),
-            "idof": ConcurrentDOFInameTag()})
-
-    return DOFArray(
-        actx,
-        data=tuple(
-            actx.call_loopy(elementwise_prg(), operand=vec_i)["result"]
-            for vec_i in vec
+    else:
+        @memoize_in(actx, (_apply_elementwise_reduction,
+                        "elementwise_%s_prg" % op_name))
+        def elementwise_prg():
+            # FIXME: This computes the reduction value redundantly for each
+            # output DOF.
+            t_unit = make_loopy_program(
+                [
+                    "{[iel]: 0 <= iel < nelements}",
+                    "{[idof, jdof]: 0 <= idof, jdof < ndofs}"
+                ],
+                """
+                    result[iel, idof] = %s(jdof, operand[iel, jdof])
+                """ % op_name,
+                name="grudge_elementwise_%s_knl" % op_name
+            )
+            import loopy as lp
+            from meshmode.transform_metadata import (
+                    ConcurrentElementInameTag, ConcurrentDOFInameTag)
+            return lp.tag_inames(t_unit, {
+                "iel": ConcurrentElementInameTag(),
+                "idof": ConcurrentDOFInameTag()})
+
+        return DOFArray(
+            actx,
+            data=tuple(
+                actx.call_loopy(elementwise_prg(), operand=vec_i)["result"]
+                for vec_i in vec
+            )
         )
-    )
 
 
 def elementwise_sum(dcoll: DiscretizationCollection, *args) -> DOFArray:
-- 
GitLab