From 3b8d5644fe48e95ec7ea85284164b3b7a75f5e1b Mon Sep 17 00:00:00 2001
From: Thomas Gibson <gibsonthomas1120@hotmail.com>
Date: Wed, 28 Apr 2021 14:40:47 -0500
Subject: [PATCH] Rewrite wadg inverse for readability/debuggability

---
 grudge/op.py | 45 +++++++++++++++++----------------------------
 1 file changed, 17 insertions(+), 28 deletions(-)

diff --git a/grudge/op.py b/grudge/op.py
index b6d5e635..8acf3c83 100644
--- a/grudge/op.py
+++ b/grudge/op.py
@@ -456,36 +456,25 @@ def _apply_inverse_mass_operator(dcoll, dd_out, dd_in, vec):
     actx = vec.array_context
     inv_area_elements = 1./area_element(actx, dcoll, dd=dd_in)
     if use_wadg:
-        # FIXME: This feels just wrong; can't we be smarter here?
-        # NOTE: Could just write a loopy program directly
-        return DOFArray(
-            actx,
-            tuple(actx.einsum("ik,km,em,mj,ej->ei",
-                            reference_inverse_mass_matrix(
-                                actx,
-                                element_group=grp
-                            ),
-                            reference_mass_matrix(
-                                actx,
-                                out_element_group=grp,
-                                in_element_group=grp
-                            ),
-                            iae_i,
-                            reference_inverse_mass_matrix(
-                                actx,
-                                element_group=grp
-                            ),
-                            vec_i,
-                            # Ew. Fix this `mass_inv_mat_x` business
-                            arg_names=(
-                                "mass_inv_mat_1", "mass_mat",
-                                "jac_det_inv", "mass_inv_mat_2", "vec"
-                            ),
+        # FIXME: Think of how to compose existing functions here...
+        # NOTE: Rewritten for readability/debuggability
+        result = discr.empty(actx, dtype=vec.entry_dtype)
+        grps = discr.groups
+        data = []
+        for grp, jac_inv, x in zip(grps, inv_area_elements, vec):
+            ref_mass = reference_mass_matrix(actx,
+                                             out_element_group=grp,
+                                             in_element_group=grp)
+            ref_mass_inv = reference_inverse_mass_matrix(actx,
+                                                         element_group=grp)
+            data.append(
+                # Based on https://arxiv.org/pdf/1608.03836.pdf
+                # true_Minv ~ ref_Minv * ref_M * (1/jac_det) * ref_Minv
+                actx.einsum("ik,km,em,mj,ej->ei",
+                            ref_mass_inv, ref_mass, jac_inv, ref_mass_inv, x,
                             tagged=(MassOperatorTag(),))
-
-                for grp, iae_i, vec_i in zip(discr.groups, inv_area_elements, vec)
             )
-        )
+        return DOFArray(actx, data=tuple(data))
     else:
         return DOFArray(
             actx,
-- 
GitLab