From 5f1a7c319638eb9d1a74c3af284a2519d7fadccc Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 8 Oct 2015 12:37:36 -0500
Subject: [PATCH] More compiler hackery

---
 grudge/symbolic/compiler.py | 106 +++++++++++++++++++++++++++++++++++-
 1 file changed, 104 insertions(+), 2 deletions(-)

diff --git a/grudge/symbolic/compiler.py b/grudge/symbolic/compiler.py
index e8b0abb..8dfbc85 100644
--- a/grudge/symbolic/compiler.py
+++ b/grudge/symbolic/compiler.py
@@ -266,6 +266,62 @@ class FluxExchangeBatchAssign(Instruction):
     def get_executor_method(self, executor):
         return executor.exec_flux_exchange_batch_assign
 
+
+class VectorExprAssign(Assign):
+    """An :class:`Assign` that executors run via ``exec_vector_expr_assign``."""
+
+    # "compiled" must not be listed here: a slot name that is also a class
+    # attribute (the method below) raises ValueError on Python 3.
+    __slots__ = []
+    comment = "compiled"
+
+    def get_executor_method(self, executor):
+        return executor.exec_vector_expr_assign
+
+    @memoize_method
+    def compiled(self, executor):
+        """Return the ``CompiledVectorExpression`` built from this
+        instruction's *names*, *exprs* and *do_not_return* entries."""
+        from hedge.backends.vector_expr import \
+                VectorExpressionInfo, simple_result_dtype_getter
+        from hedge.backends.cuda.vector_expr import CompiledVectorExpression
+        return CompiledVectorExpression(
+                [VectorExpressionInfo(name=name, expr=expr, do_not_return=dnr)
+                    for name, expr, dnr in zip(
+                        self.names, self.exprs, self.do_not_return)],
+                result_dtype_getter=simple_result_dtype_getter,
+                allocator=executor.discr.pool.allocate)
+
+
+class CUDAFluxBatchAssign(FluxBatchAssign):
+    """Flux batch assignment whose kernel is made by a flux plan.
+
+    :ivar quadrature_tag: if *None*, ``discr.flux_plan`` is used; otherwise
+        the plan comes from ``discr.get_cuda_quadrature_info(quadrature_tag)``.
+    """
+
+    @memoize_method
+    def get_dependencies(self):
+        """Return the dep-mapper results for all interior/boundary deps."""
+        flux_deps = set()
+        for flux in self.expressions:
+            flux_deps.update(flux.interior_deps, flux.boundary_deps)
+
+        mapper = self.dep_mapper_factory()
+        from pytools import flatten
+        return set(flatten(mapper(d) for d in flux_deps))
+
+    @memoize_method
+    def kernel(self, executor):
+        """Return the kernel the applicable flux plan makes for *executor*."""
+        discr = executor.discr
+        tag = self.quadrature_tag
+        # Without a quadrature tag, fall back to the discretization's own plan.
+        plan = (discr.flux_plan if tag is None
+                else discr.get_cuda_quadrature_info(tag).flux_plan)
+        return plan.make_kernel(discr, executor, self.expressions)
+
+
 # }}}
 
 
@@ -582,8 +638,19 @@ class OperatorCompiler(IdentityMapper):
         instances with fields `flux_expr` and `dependencies`.
         """
 
-        # overridden by subclasses
-        raise NotImplementedError
+        from grudge.symbolic.mappers import FluxCollector
+        contained_flux_ops = FluxCollector()(expr)
+
+        from pytools import all
+        assert all(isinstance(op, sym.WholeDomainFluxOperator)
+                for op in contained_flux_ops), \
+                        "not all flux operators were of the expected type"
+
+        return [self.FluxRecord(
+            flux_expr=wdflux,
+            dependencies=set(wdflux.interior_deps) | set(wdflux.boundary_deps),
+            repr_op=wdflux.repr_op())
+            for wdflux in contained_flux_ops]
 
     def collect_diff_ops(self, expr):
         from grudge.symbolic.operators import ReferenceDiffOperatorBase
@@ -896,6 +963,18 @@ class OperatorCompiler(IdentityMapper):
     def make_flux_batch_assign(self, names, expressions, repr_op):
         return FluxBatchAssign(names=names, expressions=expressions, repr_op=repr_op)
 
+    # Disabled variant taken from the CUDA backend:
+    # def make_flux_batch_assign(self, names, expressions, repr_op):
+    #     from pytools import single_valued
+    #     quadrature_tag = single_valued(
+    #             wdflux.quadrature_tag
+    #             for wdflux in expressions)
+
+    #     return CUDAFluxBatchAssign(
+    #             names=names, expressions=expressions, repr_op=repr_op,
+    #             dep_mapper_factory=self.dep_mapper_factory,
+    #             quadrature_tag=quadrature_tag)
+
     # }}}
 
     # {{{ assignment aggregration pass
@@ -1102,6 +1181,29 @@ class OperatorCompiler(IdentityMapper):
 
     # }}}
 
+    def internal_map_flux(self, wdflux):
+        """Return *wdflux* rebuilt with ``self.rec``-mapped operands."""
+        from hedge.optemplate.operators import WholeDomainFluxOperator
+        interiors = [
+            wdflux.InteriorInfo(
+                flux_expr=info.flux_expr, field_expr=self.rec(info.field_expr))
+            for info in wdflux.interiors]
+        boundaries = [
+            wdflux.BoundaryInfo(
+                flux_expr=info.flux_expr, bpair=self.rec(info.bpair))
+            for info in wdflux.boundaries]
+        return WholeDomainFluxOperator(
+            wdflux.is_lift, interiors, boundaries, wdflux.quadrature_tag)
+
+    def map_whole_domain_flux(self, wdflux):
+        return self.map_planned_flux(wdflux)  # delegates to map_planned_flux
+
+    def finalize_multi_assign(self, names, exprs, do_not_return, priority):
+        """Build the :class:`VectorExprAssign` for the given assignment data."""
+        return VectorExprAssign(
+                names=names, exprs=exprs, do_not_return=do_not_return,
+                dep_mapper_factory=self.dep_mapper_factory, priority=priority)
+
 # }}}
 
 
-- 
GitLab