From c744fbcce9c150840dcd025058ef63b02282f571 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 5 Jul 2017 16:04:41 +0200
Subject: [PATCH] Teach loopy execution about Ones, If, Comparison, mop up
 scalar .get() cases

---
 grudge/execution.py                  |  7 ++++---
 grudge/symbolic/compiler.py          | 14 ++++++++++++--
 grudge/symbolic/dofdesc_inference.py | 14 ++++++++++----
 test/test_grudge.py                  |  9 +++++----
 4 files changed, 31 insertions(+), 13 deletions(-)

diff --git a/grudge/execution.py b/grudge/execution.py
index 42749b50..69e4b7a4 100644
--- a/grudge/execution.py
+++ b/grudge/execution.py
@@ -337,19 +337,20 @@ class ExecutionMapper(mappers.Evaluator,
         for name, expr in six.iteritems(kdescr.input_mappings):
             kwargs[name] = self.rec(expr)
 
-        vdiscr = self.discr.volume_discr
+        discr = self.get_discr(kdescr.governing_dd)
         for name in kdescr.scalar_args():
             v = kwargs[name]
             if isinstance(v, (int, float)):
-                kwargs[name] = vdiscr.real_dtype.type(v)
+                kwargs[name] = discr.real_dtype.type(v)
             elif isinstance(v, complex):
-                kwargs[name] = vdiscr.complex_dtype.type(v)
+                kwargs[name] = discr.complex_dtype.type(v)
             elif isinstance(v, np.number):
                 pass
             else:
                 raise ValueError("unrecognized scalar type for variable '%s': %s"
                         % (name, type(v)))
 
+        kwargs["grdg_n"] = discr.nnodes
         evt, result_dict = kdescr.loopy_kernel(self.queue, **kwargs)
         return list(result_dict.items()), []
 
diff --git a/grudge/symbolic/compiler.py b/grudge/symbolic/compiler.py
index c2f2e5b6..492c3e29 100644
--- a/grudge/symbolic/compiler.py
+++ b/grudge/symbolic/compiler.py
@@ -72,11 +72,12 @@ def _make_dep_mapper(include_subscripts):
 
 class LoopyKernelDescriptor(object):
     def __init__(self, loopy_kernel, input_mappings, output_mappings,
-            fixed_arguments):
+            fixed_arguments, governing_dd):
         self.loopy_kernel = loopy_kernel
         self.input_mappings = input_mappings
         self.output_mappings = output_mappings
         self.fixed_arguments = fixed_arguments
+        self.governing_dd = governing_dd
 
     @memoize_method
     def scalar_args(self):
@@ -852,6 +853,9 @@ class ToLoopyExpressionMapper(mappers.IdentityMapper):
                     "do not know how to map function '%s' into loopy"
                     % expr.function)
 
+    def map_ones(self, expr):
+        return 1
+
     def map_node_coordinate_component(self, expr):
         mapped_name = "grdg_ncc%d" % expr.axis
         set_once(self.input_mappings, mapped_name, expr)
@@ -905,12 +909,18 @@ class ToLoopyInstructionMapper(object):
         knl = lp.set_options(knl, return_dict=True)
         knl = lp.split_iname(knl, iname, 128, outer_tag="g.0", inner_tag="l.0")
 
+        from pytools import single_valued
+        governing_dd = single_valued(
+                self.dd_inference_mapper(expr)
+                for expr in insn.exprs)
+
         return LoopyKernelInstruction(
             LoopyKernelDescriptor(
                 loopy_kernel=knl,
                 input_mappings=expr_mapper.input_mappings,
                 output_mappings=expr_mapper.output_mappings,
-                fixed_arguments={})
+                fixed_arguments={},
+                governing_dd=governing_dd)
             )
 
     def map_insn_assign_to_discr_scoped(self, insn):
diff --git a/grudge/symbolic/dofdesc_inference.py b/grudge/symbolic/dofdesc_inference.py
index 6b8cbebb..9cb54357 100644
--- a/grudge/symbolic/dofdesc_inference.py
+++ b/grudge/symbolic/dofdesc_inference.py
@@ -126,7 +126,7 @@ class DOFDescInferenceMapper(RecursiveMapper, CSECachingMapperMixin):
         # FIXME: Subscript has same type as aggregate--a bit weird
         return self.rec(expr.aggregate)
 
-    def map_arithmetic(self, expr, children):
+    def map_multi_child(self, expr, children):
         dofdesc = None
 
         for ch in children:
@@ -138,15 +138,21 @@ class DOFDescInferenceMapper(RecursiveMapper, CSECachingMapperMixin):
             return dofdesc
 
     def map_sum(self, expr):
-        return self.map_arithmetic(expr, expr.children)
+        return self.map_multi_child(expr, expr.children)
 
     map_product = map_sum
 
     def map_quotient(self, expr):
-        return self.map_arithmetic(expr, (expr.numerator, expr.denominator))
+        return self.map_multi_child(expr, (expr.numerator, expr.denominator))
 
     def map_power(self, expr):
-        return self.map_arithmetic(expr, (expr.base, expr.exponent))
+        return self.map_multi_child(expr, (expr.base, expr.exponent))
+
+    def map_if(self, expr):
+        return self.map_multi_child(expr, [expr.condition, expr.then, expr.else_])
+
+    def map_comparison(self, expr):
+        return self.map_multi_child(expr, [expr.left, expr.right])
 
     def map_nodal_sum(self, expr, enclosing_prec):
         return DOFDesc(DTAG_SCALAR)
diff --git a/test/test_grudge.py b/test/test_grudge.py
index 33725674..214f584c 100644
--- a/test/test_grudge.py
+++ b/test/test_grudge.py
@@ -106,9 +106,9 @@ def test_1d_mass_mat_trig(ctx_factory):
 
     mass_op = bind(discr, sym.MassOperator()(sym.var("f")))
 
-    num_integral_1 = np.dot(ones.get(), mass_op(queue, f=f).get())
-    num_integral_2 = np.dot(f.get(), mass_op(queue, f=ones).get())
-    num_integral_3 = bind(discr, sym.integral(sym.var("f")))(queue, f=f).get()
+    num_integral_1 = np.dot(ones.get(), mass_op(queue, f=f))
+    num_integral_2 = np.dot(f.get(), mass_op(queue, f=ones))
+    num_integral_3 = bind(discr, sym.integral(sym.var("f")))(queue, f=f)
 
     true_integral = 13*np.pi/2
     err_1 = abs(num_integral_1-true_integral)
@@ -211,6 +211,7 @@ def test_2d_gauss_theorem(ctx_factory):
 @pytest.mark.parametrize("op_type", ["strong", "weak"])
 @pytest.mark.parametrize("flux_type", ["upwind"])
 @pytest.mark.parametrize("order", [3, 4, 5])
+# test: 'test_convergence_advec(cl._csc, "disk", [0.1, 0.05], "strong", "upwind", 3)'
 def test_convergence_advec(ctx_factory, mesh_name, mesh_pars, op_type, flux_type,
         order, visualize=False):
     """Test whether 2D advection actually converges"""
@@ -322,7 +323,7 @@ def test_convergence_advec(ctx_factory, mesh_name, mesh_pars, op_type, flux_type
 
         error_l2 = bind(discr,
             sym.norm(2, sym.var("u")-u_analytic(sym.nodes(dim))))(
-                queue, t=last_t, u=last_u).get()
+                queue, t=last_t, u=last_u)
         print(h, error_l2)
         eoc_rec.add_data_point(h, error_l2)
 
-- 
GitLab