From 5ae2a0010db563f537479e996a93b832dfb09822 Mon Sep 17 00:00:00 2001
From: Thomas Gibson <gibsonthomas1120@hotmail.com>
Date: Tue, 11 May 2021 20:11:38 -0500
Subject: [PATCH] Remove stateful updates in face mass matrix application

---
 grudge/op.py | 68 +++++++++++++++++++++++++++-------------------------
 1 file changed, 36 insertions(+), 32 deletions(-)

diff --git a/grudge/op.py b/grudge/op.py
index 04082d16..7f953014 100644
--- a/grudge/op.py
+++ b/grudge/op.py
@@ -709,50 +709,54 @@ def _apply_face_mass_operator(dcoll, dd, vec):
     face_discr = dcoll.discr_from_dd(dd)
     dtype = vec.entry_dtype
     actx = vec.array_context
-    surf_area_elements = area_element(actx, dcoll, dd=dd)
 
     @memoize_in(actx, (_apply_face_mass_operator, "face_mass_knl"))
     def prg():
         return make_loopy_program(
-            """{[iel,idof,f,j]:
-                0<=iel<nelements and
-                0<=f<nfaces and
-                0<=idof<nvol_nodes and
-                0<=j<nface_nodes}""",
+            [
+                "{[iel]: 0 <= iel < nelements}",
+                "{[f]: 0 <= f < nfaces}",
+                "{[idof]: 0 <= idof < nvol_nodes}",
+                "{[jdof]: 0 <= jdof < nface_nodes}"
+            ],
             """
-            result[iel,idof] = sum(f, sum(j, mat[idof, f, j]        \
-                                             * jac_surf[f, iel, j]  \
-                                             * vec[f, iel, j]))
+            result[iel, idof] = sum(f, sum(jdof, mat[idof, f, jdof]        \
+                                                 * jac_surf[f, iel, jdof]  \
+                                                 * vec[f, iel, jdof]))
             """,
             name="face_mass"
         )
 
-    result = volm_discr.empty(actx, dtype=dtype)
     assert len(face_discr.groups) == len(volm_discr.groups)
+    surf_area_elements = area_element(actx, dcoll, dd=dd)
 
-    for afgrp, volgrp in zip(face_discr.groups, volm_discr.groups):
-
-        nfaces = volgrp.mesh_el_group.nfaces
-        matrix = reference_face_mass_matrix(
-            actx,
-            face_element_group=afgrp,
-            vol_element_group=volgrp,
-            dtype=dtype
-        )
-        input_view = vec[afgrp.index].reshape(
-            nfaces, volgrp.nelements, afgrp.nunit_dofs
-        )
-        jac_surf = surf_area_elements[afgrp.index].reshape(
-            nfaces, volgrp.nelements, afgrp.nunit_dofs
-        )
-        actx.call_loopy(
-            prg(),
-            mat=matrix,
-            result=result[volgrp.index],
-            jac_surf=jac_surf,
-            vec=input_view
+    return DOFArray(
+        actx,
+        data=tuple(
+            actx.call_loopy(prg(),
+                            mat=reference_face_mass_matrix(
+                                actx,
+                                face_element_group=afgrp,
+                                vol_element_group=vgrp,
+                                dtype=dtype
+                            ),
+                            jac_surf=surf_ae_i.reshape(
+                                vgrp.mesh_el_group.nfaces,
+                                vgrp.nelements,
+                                afgrp.nunit_dofs
+                            ),
+                            vec=vec_i.reshape(
+                                vgrp.mesh_el_group.nfaces,
+                                vgrp.nelements,
+                                afgrp.nunit_dofs
+                            ))["result"]
+
+            for vgrp, afgrp, vec_i, surf_ae_i in zip(volm_discr.groups,
+                                                     face_discr.groups,
+                                                     vec,
+                                                     surf_area_elements)
         )
-    return result
+    )
 
 
 def face_mass(dcoll, *args):
-- 
GitLab