From 91833734201d7bbb4f43dba90877706b3cdb035c Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 18 Feb 2021 11:34:20 -0600
Subject: [PATCH] Make reference diff use elwise_linear loopy kernel

---
 grudge/execution.py | 104 +++++++++++++++++---------------------------
 1 file changed, 39 insertions(+), 65 deletions(-)

diff --git a/grudge/execution.py b/grudge/execution.py
index 59e32a0d..dba0648d 100644
--- a/grudge/execution.py
+++ b/grudge/execution.py
@@ -25,7 +25,6 @@ from numbers import Number
 import numpy as np
 
 from pytools import memoize_in
-from pytools.obj_array import make_obj_array
 
 import loopy as lp
 import pyopencl as cl
@@ -254,13 +253,7 @@ class ExecutionMapper(mappers.Evaluator,
         raise NotImplementedError(
                 "differentiation should be happening in batched form")
 
-    def map_elementwise_linear(self, op, field_expr):
-        field = self.rec(field_expr)
-
-        from grudge.tools import is_zero
-        if is_zero(field):
-            return 0
-
+    def _elwise_linear_loopy_prg(self):
         @memoize_in(self.array_context, (ExecutionMapper, "elwise_linear_knl"))
         def prg():
             result = make_loopy_program(
@@ -269,18 +262,27 @@ class ExecutionMapper(mappers.Evaluator,
                     0<=idof<ndiscr_nodes_out and
                     0<=j<ndiscr_nodes_in}""",
                 "result[iel, idof] = sum(j, mat[idof, j] * vec[iel, j])",
-                name="diff")
+                name="elwise_linear")
 
             result = lp.tag_array_axes(result, "mat", "stride:auto,stride:auto")
             return result
 
+        return prg()
+
+    def map_elementwise_linear(self, op, field_expr):
+        field = self.rec(field_expr)
+
+        from grudge.tools import is_zero
+        if is_zero(field):
+            return 0
+
         in_discr = self.discrwb.discr_from_dd(op.dd_in)
         out_discr = self.discrwb.discr_from_dd(op.dd_out)
 
         result = out_discr.empty(self.array_context, dtype=field.entry_dtype)
 
+        prg = self._elwise_linear_loopy_prg()
         for in_grp, out_grp in zip(in_discr.groups, out_discr.groups):
-
             cache_key = "elwise_linear", in_grp, out_grp, op, field.entry_dtype
             try:
                 matrix = self.bound_op.operator_data_cache[cache_key]
@@ -294,7 +296,7 @@ class ExecutionMapper(mappers.Evaluator,
                 self.bound_op.operator_data_cache[cache_key] = matrix
 
             self.array_context.call_loopy(
-                    prg(),
+                    prg,
                     mat=matrix,
                     result=result[out_grp.index],
                     vec=field[in_grp.index])
@@ -485,70 +487,42 @@ class ExecutionMapper(mappers.Evaluator,
     def map_insn_diff_batch_assign(self, insn, profile_data=None):
         field = self.rec(insn.field)
         repr_op = insn.operators[0]
-        # FIXME: There's no real reason why differentiation is special,
-        # execution-wise.
-        # This should be unified with map_elementwise_linear, which should
-        # be extended to support batching.
 
         assert repr_op.dd_in.domain_tag == repr_op.dd_out.domain_tag
 
-        @memoize_in(self.array_context,
-                (ExecutionMapper, "reference_derivative_prg"))
-        def prg(nmatrices):
-            result = make_loopy_program(
-                """{[imatrix, iel, idof, j]:
-                    0<=imatrix<nmatrices and
-                    0<=iel<nelements and
-                    0<=idof<nunit_nodes_out and
-                    0<=j<nunit_nodes_in}""",
-                """
-                result[imatrix, iel, idof] = sum(
-                        j, diff_mat[imatrix, idof, j] * vec[iel, j])
-                """,
-                name="diff")
+        in_discr = self.discrwb.discr_from_dd(repr_op.dd_in)
+        out_discr = self.discrwb.discr_from_dd(repr_op.dd_out)
 
-            result = lp.fix_parameters(result, nmatrices=nmatrices)
-            result = lp.tag_inames(result, "imatrix: unr")
-            result = lp.tag_array_axes(result, "result", "sep,c,c")
-            return result
+        prg = self._elwise_linear_loopy_prg()
 
-        noperators = len(insn.operators)
+        result = []
+        for name, op in zip(insn.names, insn.operators):
+            group_results = []
+            for in_grp, out_grp in zip(in_discr.groups, out_discr.groups):
+                assert in_grp.nelements == out_grp.nelements
 
-        in_discr = self.discrwb.discr_from_dd(repr_op.dd_in)
-        out_discr = self.discrwb.discr_from_dd(repr_op.dd_out)
+                if in_grp.nelements == 0:
+                    continue
 
-        result = make_obj_array([
-            out_discr.empty(self.array_context, dtype=field.entry_dtype)
-            for idim in range(noperators)])
+                # Cache operator
+                cache_key = "diff_batch", in_grp, out_grp, tuple(insn.operators),\
+                    field.entry_dtype
+                try:
+                    matrices_dev = self.bound_op.operator_data_cache[cache_key]
+                except KeyError:
+                    matrices_dev = [self.array_context.from_numpy(mat)
+                            for mat in repr_op.matrices(out_grp, in_grp)]
+                    self.bound_op.operator_data_cache[cache_key] = matrices_dev
 
-        for in_grp, out_grp in zip(in_discr.groups, out_discr.groups):
-            if in_grp.nelements == 0:
-                continue
+                group_results.append(self.array_context.call_loopy(
+                        prg,
+                        mat=matrices_dev[op.rst_axis],
+                        vec=field[in_grp.index])["result"])
 
-            # Cache operator
-            cache_key = "diff_batch", in_grp, out_grp, tuple(insn.operators),\
-                field.entry_dtype
-            try:
-                matrices_ary_dev = self.bound_op.operator_data_cache[cache_key]
-            except KeyError:
-                matrices = repr_op.matrices(out_grp, in_grp)
-                matrices_ary = np.empty(
-                    (noperators, out_grp.nunit_dofs, in_grp.nunit_dofs),
-                    dtype=field.entry_dtype)
-                for i, op in enumerate(insn.operators):
-                    matrices_ary[i] = matrices[op.rst_axis]
-                matrices_ary_dev = self.array_context.from_numpy(matrices_ary)
-                self.bound_op.operator_data_cache[cache_key] = matrices_ary_dev
+            result.append(
+                    (name, DOFArray(self.array_context, tuple(group_results))))
 
-            self.array_context.call_loopy(
-                    prg(noperators),
-                    diff_mat=matrices_ary_dev,
-                    result=make_obj_array([
-                        result[iop][out_grp.index]
-                        for iop in range(noperators)
-                        ]), vec=field[in_grp.index])
-
-        return [(name, result[i]) for i, name in enumerate(insn.names)], []
+        return result, []
 
     # }}}
 
-- 
GitLab