From 6eb6e9d3df1b24f17b56f6adebb91072056069fe Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 22 Apr 2021 12:34:21 -0500
Subject: [PATCH] Drafty einsum-based mass operator

---
 grudge/op.py                  | 25 +++++++++++++++++++----
 test/test_new_world_grudge.py | 37 ++++++++++++++++++++---------------
 2 files changed, 42 insertions(+), 20 deletions(-)

diff --git a/grudge/op.py b/grudge/op.py
index c1270df0..66ac1183 100644
--- a/grudge/op.py
+++ b/grudge/op.py
@@ -352,12 +352,24 @@ def reference_mass_matrix(actx, out_element_group, in_element_group):
     )
 
 
-def _apply_mass_operator(dcoll, dd, vec):
-    pass
+def _apply_mass_operator(dcoll, dd_out, dd_in, vec):
 
+    in_discr = dcoll.discr_from_dd(dd_in)
+    out_discr = dcoll.discr_from_dd(dd_out)
+
+    actx = vec.array_context
+    return [
+        actx.np.einsum("ij,j,ej->ei",
+            reference_mass_matrix(actx,
+                out_element_group=out_grp,
+                in_element_group=in_grp),
+            # FIXME: These are not area elements!
+            in_discr.zeros(actx),
+            vec_i)
+        for in_grp, out_grp, vec_i in zip(in_discr.groups, out_discr.groups, vec)]
 
-def mass_operator(dcoll, *args):
 
+def mass_operator(dcoll, *args):
     if len(args) == 1:
         vec, = args
         dd = sym.DOFDesc("vol", sym.QTAG_NONE)
@@ -371,7 +383,12 @@ def mass_operator(dcoll, *args):
             lambda el: mass_operator(dcoll, dd, el), vec
         )
 
-    return _apply_mass_operator(dcoll, dd, vec)
+    dd_in = dd
+    del dd
+    from grudge.dof_desc import QTAG_NONE
+    dd_out = dd_in.with_qtag(QTAG_NONE)
+
+    return _apply_mass_operator(dcoll, dd_out, dd_in, vec)
 
 
 @memoize_on_first_arg
diff --git a/test/test_new_world_grudge.py b/test/test_new_world_grudge.py
index 82bfb3e2..a1fff8dd 100644
--- a/test/test_new_world_grudge.py
+++ b/test/test_new_world_grudge.py
@@ -33,8 +33,10 @@ from meshmode.array_context import (  # noqa
 import meshmode.mesh.generation as mgen
 
 from pytools.obj_array import flat_obj_array, make_obj_array
+import grudge.op as op
 
 from grudge import DiscretizationCollection
+import grudge.dof_desc as dof_desc
 
 import pytest
 
@@ -46,7 +48,7 @@ logger = logging.getLogger(__name__)
 # {{{ mass operator trig integration
 
 @pytest.mark.parametrize("ambient_dim", [1, 2, 3])
-@pytest.mark.parametrize("quad_tag", [sym.QTAG_NONE, "OVSMP"])
+@pytest.mark.parametrize("quad_tag", [dof_desc.QTAG_NONE, "OVSMP"])
 def test_mass_mat_trig(actx_factory, ambient_dim, quad_tag):
     """Check the integral of some trig functions on an interval using the mass
     matrix.
@@ -63,9 +65,9 @@ def test_mass_mat_trig(actx_factory, ambient_dim, quad_tag):
     from meshmode.discretization.poly_element import \
         QuadratureSimplexGroupFactory
 
-    dd_quad = sym.DOFDesc(sym.DTAG_VOLUME_ALL, quad_tag)
+    dd_quad = dof_desc.DOFDesc(dof_desc.DTAG_VOLUME_ALL, quad_tag)
 
-    if quad_tag is sym.QTAG_NONE:
+    if quad_tag is dof_desc.QTAG_NONE:
         quad_tag_to_group_factory = {}
     else:
         quad_tag_to_group_factory = {
@@ -85,22 +87,22 @@ def test_mass_mat_trig(actx_factory, ambient_dim, quad_tag):
     def f(x):
         return actx.np.sin(x[0])**2
 
-    def ones(x):
-        return actx.np.ones(x.shape)
-
-    volm_disc = dcoll.disc_from_dd(sym.DD_VOLUME)
+    volm_disc = dcoll.discr_from_dd(dof_desc.DD_VOLUME)
     x_volm = thaw(actx, volm_disc.nodes())
-    f_volm = f(x_vol)
-    ones_volm = ones(x_vol)
+    f_volm = f(x_volm)
+    ones_volm = volm_disc.zeros(actx) + 1
 
-    quad_disc = dcoll.disc_from_dd(dd_quad)
+    quad_disc = dcoll.discr_from_dd(dd_quad)
     x_quad = thaw(actx, quad_disc.nodes())
     f_quad = f(x_quad)
-    ones_quad = ones(x_quad)
+    ones_quad = quad_disc.zeros(actx) + 1
 
-    mass_op = bind(discr, sym.MassOperator(dd_quad, sym.DD_VOLUME)(sym_f))
+    #mass_op = bind(discr, dof_desc.MassOperator(dd_quad, dof_desc.DD_VOLUME)(sym_f))
 
-    num_integral_1 = np.dot(ones_volm, actx.to_numpy(flatten(mass_op(f=f_quad))))
+    num_integral_1 = np.dot(ones_volm,
+            actx.to_numpy(flatten(
+                op.mass_operator(dcoll, dd_quad, f_quad))))
+    1/0
     err_1 = abs(num_integral_1 - true_integral)
     assert err_1 < 1e-9, err_1
 
@@ -108,12 +110,15 @@ def test_mass_mat_trig(actx_factory, ambient_dim, quad_tag):
     err_2 = abs(num_integral_2 - true_integral)
     assert err_2 < 1.0e-9, err_2
 
-    if quad_tag is sym.QTAG_NONE:
+    if quad_tag is dof_desc.QTAG_NONE:
         # NOTE: `integral` always makes a square mass matrix and
         # `QuadratureSimplexGroupFactory` does not have a `mass_matrix` method.
         num_integral_3 = bind(discr,
-                sym.integral(sym_f, dd=dd_quad))(f=f_quad)
+                dof_desc.integral(sym_f, dd=dd_quad))(f=f_quad)
         err_3 = abs(num_integral_3 - true_integral)
         assert err_3 < 5.0e-10, err_3
 
-# }}}
\ No newline at end of file
+# }}}
+
+
+# vim: foldmethod=marker
-- 
GitLab