From 9de788740ca2a1b84cf0cefa248f021c5f5a1eb1 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 4 Sep 2021 20:09:51 -0500
Subject: [PATCH] Introduce, use inverse_surface_metric_derivative_mat (closes
 gh-142)

---
 grudge/geometry/__init__.py |  2 ++
 grudge/geometry/metrics.py  | 51 +++++++++++++++++++++++++++----------
 grudge/op.py                | 32 ++++++++---------------
 3 files changed, 50 insertions(+), 35 deletions(-)

diff --git a/grudge/geometry/__init__.py b/grudge/geometry/__init__.py
index fec1e595..6bbde3f7 100644
--- a/grudge/geometry/__init__.py
+++ b/grudge/geometry/__init__.py
@@ -32,6 +32,7 @@ from grudge.geometry.metrics import (
     inverse_first_fundamental_form,
 
     inverse_surface_metric_derivative,
+    inverse_surface_metric_derivative_mat,
     pseudoscalar,
     area_element,
 
@@ -51,6 +52,7 @@ __all__ = (
     "inverse_first_fundamental_form",
 
     "inverse_surface_metric_derivative",
+    "inverse_surface_metric_derivative_mat",
     "pseudoscalar",
     "area_element",
 
diff --git a/grudge/geometry/metrics.py b/grudge/geometry/metrics.py
index 919bdf46..44c21c78 100644
--- a/grudge/geometry/metrics.py
+++ b/grudge/geometry/metrics.py
@@ -15,6 +15,7 @@ Geometry terms
 --------------
 
 .. autofunction:: inverse_surface_metric_derivative
+.. autofunction:: inverse_surface_metric_derivative_mat
 .. autofunction:: pseudoscalar
 .. autofunction:: area_element
 
@@ -398,21 +399,43 @@ def inverse_surface_metric_derivative(
         dd = DD_VOLUME
     dd = dof_desc.as_dofdesc(dd)
 
-    @memoize_in(dcoll, (inverse_surface_metric_derivative, dd,
-                        rst_axis, xyz_axis))
+    if ambient_dim == dim:
+        return inverse_metric_derivative(
+            actx, dcoll, rst_axis, xyz_axis, dd=dd
+        )
+    else:
+        inv_form1 = inverse_first_fundamental_form(actx, dcoll, dd=dd)
+        return sum(
+            inv_form1[rst_axis, d]*forward_metric_nth_derivative(
+                actx, dcoll, xyz_axis, d, dd=dd
+            ) for d in range(dim))
+
+
+def inverse_surface_metric_derivative_mat(
+        actx: ArrayContext, dcoll: DiscretizationCollection, dd=None):
+    r"""Computes the matrix of inverse surface metric derivatives, indexed by
+    ``(xyz_axis, rst_axis)``. It returns all values of
+    :func:`inverse_surface_metric_derivative_mat` in cached matrix form.
+
+    This function caches its results.
+
+    :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one.
+        Defaults to the base volume discretization.
+    :returns: a :class:`~meshmode.dof_array.DOFArray` containing the
+        inverse metric derivatives in per-group arrays of shape
+        ``(xyz_dimension, rst_dimension, nelements, ndof)``.
+    """
+
+    @memoize_in(dcoll, (inverse_surface_metric_derivative_mat, dd))
     def _inv_surf_metric_deriv():
-        if ambient_dim == dim:
-            imd = inverse_metric_derivative(
-                actx, dcoll, rst_axis, xyz_axis, dd=dd
-            )
-        else:
-            inv_form1 = inverse_first_fundamental_form(actx, dcoll, dd=dd)
-            imd = sum(
-                inv_form1[rst_axis, d]*forward_metric_nth_derivative(
-                    actx, dcoll, xyz_axis, d, dd=dd
-                ) for d in range(dim)
-            )
-        return freeze(imd, actx)
+        mat = actx.np.stack([
+                actx.np.stack(
+                    [inverse_surface_metric_derivative(actx, dcoll,
+                        rst_axis, xyz_axis, dd=dd)
+                        for rst_axis in range(dcoll.dim)])
+                for xyz_axis in range(dcoll.ambient_dim)])
+
+        return freeze(mat, actx)
 
     return thaw(_inv_surf_metric_deriv(), actx)
 
diff --git a/grudge/op.py b/grudge/op.py
index 5e914d8e..3ca89789 100644
--- a/grudge/op.py
+++ b/grudge/op.py
@@ -111,26 +111,23 @@ def reference_derivative_matrices(actx: ArrayContext, element_group):
 
 
 def _compute_local_gradient(dcoll: DiscretizationCollection, vec, xyz_axis):
-    from grudge.geometry import inverse_surface_metric_derivative
+    from grudge.geometry import inverse_surface_metric_derivative_mat
 
     discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME)
     actx = vec.array_context
 
-    inverse_jac_t = actx.np.stack(
-        [inverse_surface_metric_derivative(actx, dcoll, rst_axis, xyz_axis)
-         for rst_axis in range(dcoll.dim)]
-    )
+    inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll)
     return DOFArray(
         actx,
         data=tuple(
             actx.einsum("dei,dij,ej->ei",
-                        inv_jac_t_i,
+                        ijm_i[xyz_axis],
                         reference_derivative_matrices(actx, grp),
                         vec_i,
                         arg_names=("inv_jac_t", "ref_diff_mat", "vec"),
                         tagged=(FirstAxisIsElementsTag(),))
 
-            for grp, vec_i, inv_jac_t_i in zip(discr.groups, vec, inverse_jac_t)
+            for grp, vec_i, ijm_i in zip(discr.groups, vec, inverse_jac_mat)
         )
     )
 
@@ -275,18 +272,15 @@ def reference_stiffness_transpose_matrix(
 def _apply_stiffness_transpose_operator(
         dcoll: DiscretizationCollection, dd_out, dd_in, vec, xyz_axis):
     from grudge.geometry import \
-        inverse_surface_metric_derivative, area_element
+        inverse_surface_metric_derivative_mat, area_element
 
     in_discr = dcoll.discr_from_dd(dd_in)
     out_discr = dcoll.discr_from_dd(dd_out)
 
     actx = vec.array_context
     area_elements = area_element(actx, dcoll, dd=dd_in)
-    inverse_jac_t = actx.np.stack(
-        [inverse_surface_metric_derivative(actx, dcoll,
-                                           rst_axis, xyz_axis, dd=dd_in)
-         for rst_axis in range(dcoll.dim)]
-    )
+    inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, dd=dd_in)
+
     return DOFArray(
         actx,
         data=tuple(
@@ -298,17 +292,13 @@ def _apply_stiffness_transpose_operator(
                         ),
                         ae_i,
                         vec_i,
-                        inv_jac_t_i,
+                        ijm_i[xyz_axis],
                         arg_names=("ref_stiffT_mat", "jac", "vec", "inv_jac_t"),
                         tagged=(FirstAxisIsElementsTag(),))
 
-            for out_grp, in_grp, vec_i, ae_i, inv_jac_t_i in zip(out_discr.groups,
-                                                                 in_discr.groups,
-                                                                 vec,
-                                                                 area_elements,
-                                                                 inverse_jac_t)
-        )
-    )
+            for out_grp, in_grp, vec_i, ae_i, ijm_i in zip(
+                out_discr.groups, in_discr.groups, vec, area_elements,
+                inverse_jac_mat)))
 
 
 def weak_local_grad(dcoll: DiscretizationCollection, *args, nested=False):
-- 
GitLab