From 73c0b9396500f21f64019f3fd232258f15008ef1 Mon Sep 17 00:00:00 2001
From: "[6~" <inform@tiker.net>
Date: Fri, 22 May 2020 00:14:48 -0500
Subject: [PATCH] Switch eager example over to weak DG

---
 examples/wave/wave-eager.py | 17 ++++++-----------
 grudge/eager.py             | 11 +++++++++++
 2 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/examples/wave/wave-eager.py b/examples/wave/wave-eager.py
index f80ff2c7..aa9e87ff 100644
--- a/examples/wave/wave-eager.py
+++ b/examples/wave/wave-eager.py
@@ -69,12 +69,7 @@ def wave_flux(discr, c, w_tpair):
             0.5*normal_times(v_jump),
             )
 
-    flux_strong = join_fields(
-            np.dot(v.int, normal),
-            normal_times(u.int),
-            ) - flux_weak
-
-    return discr.interp(w_tpair.dd, "all_faces", c*flux_strong)
+    return discr.interp(w_tpair.dd, "all_faces", c*flux_weak)
 
 
 def wave_operator(discr, c, w):
@@ -87,12 +82,12 @@ def wave_operator(discr, c, w):
     dir_bc = join_fields(-dir_u, dir_v)
 
     return (
-            - join_fields(
-                -c*discr.div(v),
-                -c*discr.grad(u)
-                )
-            +  # noqa: W504
             discr.inverse_mass(
+                join_fields(
+                    c*discr.weak_div(v),
+                    c*discr.weak_grad(u)
+                    )
+                -  # noqa: W504
                 discr.face_mass(
                     wave_flux(discr, c=c, w_tpair=interior_trace_pair(discr, w))
                     + wave_flux(discr, c=c, w_tpair=TracePair(
diff --git a/grudge/eager.py b/grudge/eager.py
index 5cf380f8..e64e2fa8 100644
--- a/grudge/eager.py
+++ b/grudge/eager.py
@@ -66,6 +66,17 @@ class EagerDGDiscretization(DGDiscretizationWithBoundaries):
         return sum(
                 self.grad(vec_i)[i] for i, vec_i in enumerate(vecs))
 
+    @memoize_method
+    def _bound_weak_grad(self):
+        return bind(self, sym.stiffness_t(self.dim) * sym.Variable("u"))
+
+    def weak_grad(self, vec):
+        return self._bound_weak_grad()(vec.queue, u=vec)
+
+    def weak_div(self, vecs):
+        return sum(
+                self.weak_grad(vec_i)[i] for i, vec_i in enumerate(vecs))
+
     @memoize_method
     def normal(self, dd):
         with cl.CommandQueue(self.cl_context) as queue:
-- 
GitLab