From 04c631a76e61464e96fec4b0a11a8deb12753087 Mon Sep 17 00:00:00 2001
From: Thomas Gibson <gibsonthomas1120@hotmail.com>
Date: Mon, 26 Apr 2021 11:56:23 -0500
Subject: [PATCH] Implement mass inverse and WADG mass inverse routines

---
 grudge/op.py | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 93 insertions(+)

diff --git a/grudge/op.py b/grudge/op.py
index c3e57c5f..d53eec84 100644
--- a/grudge/op.py
+++ b/grudge/op.py
@@ -416,6 +416,99 @@ def inverse_mass(dcoll, vec):
     return _bound_inverse_mass(dcoll)(u=vec)
 
 
+def reference_inverse_mass_matrix(actx, element_group):
+    @keyed_memoize_in(
+        actx, reference_inverse_mass_matrix,
+        lambda grp: grp.discretization_key())
+    def get_ref_inv_mass_mat(grp):
+        from modepy import inverse_mass_matrix
+        basis = grp.basis_obj()
+
+        return actx.freeze(
+            actx.from_numpy(
+                inverse_mass_matrix(basis.functions, grp.unit_nodes)
+            )
+        )
+
+    return get_ref_inv_mass_mat(element_group)
+
+
+def _apply_inverse_mass_operator(dcoll, dd_out, dd_in, vec):
+    from grudge.geometry import area_element
+
+    if dd_out != dd_in:
+        raise ValueError(
+            "Cannot compute inverse of a mass matrix mapping "
+            "between different element groups; inverse is not "
+            "guaranteed to be well-defined"
+        )
+    discr = dcoll.discr_from_dd(dd_in)
+    use_wadg = not all(grp.is_affine for grp in discr.groups)
+
+    actx = vec.array_context
+    inv_area_elements = 1./area_element(actx, dcoll, dd=dd_in)
+    if use_wadg:
+        # FIXME: This feels just wrong; can't we be smarter here?
+        return DOFArray(
+            actx,
+            tuple(actx.einsum("ik,km,em,mj,ej->ei",
+                            reference_inverse_mass_matrix(
+                                actx,
+                                element_group=grp
+                            ),
+                            reference_mass_matrix(
+                                actx,
+                                out_element_group=grp,
+                                in_element_group=grp
+                            ),
+                            iae_i,
+                            reference_inverse_mass_matrix(
+                                actx,
+                                element_group=grp
+                            ),
+                            vec_i,
+                            # Ew. Fix this `mass_inv_mat_x` business
+                            arg_names=(
+                                "mass_inv_mat_1", "mass_mat",
+                                "jac_det_inv", "mass_inv_mat_2", "vec"
+                            ),
+                            tagged=(MassOperatorTag(),))
+
+                for grp, iae_i, vec_i in zip(discr.groups, inv_area_elements, vec)
+            )
+        )
+    else:
+        return DOFArray(
+            actx,
+            tuple(actx.einsum("ij,ej,ej->ei",
+                            reference_inverse_mass_matrix(
+                                actx,
+                                element_group=grp
+                            ),
+                            iae_i,
+                            vec_i,
+                            arg_names=("mass_inv_mat", "jac_det_inv", "vec"),
+                            tagged=(MassOperatorTag(),),
+                            tagged_array_axes={
+                                "mass_inv_mat": "stride:auto,stride:auto"
+                            })
+
+                for grp, iae_i, vec_i in zip(discr.groups, inv_area_elements, vec)
+            )
+        )
+
+
+def inverse_mass_operator(dcoll, vec):
+    if isinstance(vec, np.ndarray):
+        return obj_array_vectorize(
+            lambda el: inverse_mass_operator(dcoll, el), vec
+        )
+
+    return _apply_inverse_mass_operator(
+        dcoll, dof_desc.DD_VOLUME, dof_desc.DD_VOLUME, vec
+    )
+
+
 @memoize_on_first_arg
 def _bound_face_mass(dcoll, dd):
     u = sym.Variable("u", dd=dd)
-- 
GitLab