From 5f1a7c319638eb9d1a74c3af284a2519d7fadccc Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 8 Oct 2015 12:37:36 -0500 Subject: [PATCH] More compiler hackery --- grudge/symbolic/compiler.py | 106 +++++++++++++++++++++++++++++++++++- 1 file changed, 104 insertions(+), 2 deletions(-) diff --git a/grudge/symbolic/compiler.py b/grudge/symbolic/compiler.py index e8b0abb..8dfbc85 100644 --- a/grudge/symbolic/compiler.py +++ b/grudge/symbolic/compiler.py @@ -266,6 +266,62 @@ class FluxExchangeBatchAssign(Instruction): def get_executor_method(self, executor): return executor.exec_flux_exchange_batch_assign + +class VectorExprAssign(Assign): + __slots__ = ["compiled"] + + def get_executor_method(self, executor): + return executor.exec_vector_expr_assign + + comment = "compiled" + + @memoize_method + def compiled(self, executor): + discr = executor.discr + + from hedge.backends.vector_expr import \ + VectorExpressionInfo, simple_result_dtype_getter + from hedge.backends.cuda.vector_expr import CompiledVectorExpression + return CompiledVectorExpression( + [VectorExpressionInfo( + name=name, + expr=expr, + do_not_return=dnr) + for name, expr, dnr in zip( + self.names, self.exprs, self.do_not_return)], + result_dtype_getter=simple_result_dtype_getter, + allocator=discr.pool.allocate) + + +class CUDAFluxBatchAssign(FluxBatchAssign): + """ + :ivar quadrature_tag: + """ + + @memoize_method + def get_dependencies(self): + deps = set() + for wdflux in self.expressions: + deps |= set(wdflux.interior_deps) + deps |= set(wdflux.boundary_deps) + + dep_mapper = self.dep_mapper_factory() + + from pytools import flatten + return set(flatten(dep_mapper(dep) for dep in deps)) + + @memoize_method + def kernel(self, executor): + discr = executor.discr + if self.quadrature_tag is None: + flux_plan = discr.flux_plan + else: + flux_plan = discr.get_cuda_quadrature_info(self.quadrature_tag).flux_plan + + return flux_plan.make_kernel( + executor.discr, executor, self.expressions) + + # }}} @@ -582,8 +638,19 @@ class OperatorCompiler(IdentityMapper): instances with fields `flux_expr` and `dependencies`. """ - # overridden by subclasses - raise NotImplementedError + from grudge.symbolic.mappers import FluxCollector + contained_flux_ops = FluxCollector()(expr) + + from pytools import all + assert all(isinstance(op, sym.WholeDomainFluxOperator) + for op in contained_flux_ops), \ + "not all flux operators were of the expected type" + + return [self.FluxRecord( + flux_expr=wdflux, + dependencies=set(wdflux.interior_deps) | set(wdflux.boundary_deps), + repr_op=wdflux.repr_op()) + for wdflux in contained_flux_ops] def collect_diff_ops(self, expr): from grudge.symbolic.operators import ReferenceDiffOperatorBase @@ -896,6 +963,18 @@ class OperatorCompiler(IdentityMapper): def make_flux_batch_assign(self, names, expressions, repr_op): return FluxBatchAssign(names=names, expressions=expressions, repr_op=repr_op) + # from CUDA backend: + # def make_flux_batch_assign(self, names, expressions, repr_op): + # from pytools import single_valued + # quadrature_tag = single_valued( + # wdflux.quadrature_tag + # for wdflux in expressions) + + # return CUDAFluxBatchAssign( + # names=names, expressions=expressions, repr_op=repr_op, + # dep_mapper_factory=self.dep_mapper_factory, + # quadrature_tag=quadrature_tag) + # }}} # {{{ assignment aggregration pass @@ -1102,6 +1181,29 @@ class OperatorCompiler(IdentityMapper): # }}} + def internal_map_flux(self, wdflux): + from hedge.optemplate.operators import WholeDomainFluxOperator + return WholeDomainFluxOperator( + wdflux.is_lift, + [wdflux.InteriorInfo( + flux_expr=ii.flux_expr, + field_expr=self.rec(ii.field_expr)) + for ii in wdflux.interiors], + [wdflux.BoundaryInfo( + flux_expr=bi.flux_expr, + bpair=self.rec(bi.bpair)) + for bi in wdflux.boundaries], + wdflux.quadrature_tag) + + def map_whole_domain_flux(self, wdflux): + return self.map_planned_flux(wdflux) + + def finalize_multi_assign(self, names, exprs, do_not_return, priority): + return VectorExprAssign(names=names, exprs=exprs, + do_not_return=do_not_return, + dep_mapper_factory=self.dep_mapper_factory, + priority=priority) + # }}} -- GitLab