Skip to content
Snippets Groups Projects
Commit 5f1a7c31 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

More compiler hackery

parent a986a8d9
No related branches found
No related tags found
No related merge requests found
...@@ -266,6 +266,62 @@ class FluxExchangeBatchAssign(Instruction): ...@@ -266,6 +266,62 @@ class FluxExchangeBatchAssign(Instruction):
def get_executor_method(self, executor): def get_executor_method(self, executor):
return executor.exec_flux_exchange_batch_assign return executor.exec_flux_exchange_batch_assign
class VectorExprAssign(Assign):
__slots__ = ["compiled"]
def get_executor_method(self, executor):
return executor.exec_vector_expr_assign
comment = "compiled"
@memoize_method
def compiled(self, executor):
discr = executor.discr
from hedge.backends.vector_expr import \
VectorExpressionInfo, simple_result_dtype_getter
from hedge.backends.cuda.vector_expr import CompiledVectorExpression
return CompiledVectorExpression(
[VectorExpressionInfo(
name=name,
expr=expr,
do_not_return=dnr)
for name, expr, dnr in zip(
self.names, self.exprs, self.do_not_return)],
result_dtype_getter=simple_result_dtype_getter,
allocator=discr.pool.allocate)
class CUDAFluxBatchAssign(FluxBatchAssign):
"""
:ivar quadrature_tag:
"""
@memoize_method
def get_dependencies(self):
deps = set()
for wdflux in self.expressions:
deps |= set(wdflux.interior_deps)
deps |= set(wdflux.boundary_deps)
dep_mapper = self.dep_mapper_factory()
from pytools import flatten
return set(flatten(dep_mapper(dep) for dep in deps))
@memoize_method
def kernel(self, executor):
discr = executor.discr
if self.quadrature_tag is None:
flux_plan = discr.flux_plan
else:
flux_plan = discr.get_cuda_quadrature_info(self.quadrature_tag).flux_plan
return flux_plan.make_kernel(
executor.discr, executor, self.expressions)
# }}} # }}}
...@@ -582,8 +638,19 @@ class OperatorCompiler(IdentityMapper): ...@@ -582,8 +638,19 @@ class OperatorCompiler(IdentityMapper):
instances with fields `flux_expr` and `dependencies`. instances with fields `flux_expr` and `dependencies`.
""" """
# overridden by subclasses from grudge.symbolic.mappers import FluxCollector
raise NotImplementedError contained_flux_ops = FluxCollector()(expr)
from pytools import all
assert all(isinstance(op, sym.WholeDomainFluxOperator)
for op in contained_flux_ops), \
"not all flux operators were of the expected type"
return [self.FluxRecord(
flux_expr=wdflux,
dependencies=set(wdflux.interior_deps) | set(wdflux.boundary_deps),
repr_op=wdflux.repr_op())
for wdflux in contained_flux_ops]
def collect_diff_ops(self, expr): def collect_diff_ops(self, expr):
from grudge.symbolic.operators import ReferenceDiffOperatorBase from grudge.symbolic.operators import ReferenceDiffOperatorBase
...@@ -896,6 +963,18 @@ class OperatorCompiler(IdentityMapper): ...@@ -896,6 +963,18 @@ class OperatorCompiler(IdentityMapper):
def make_flux_batch_assign(self, names, expressions, repr_op): def make_flux_batch_assign(self, names, expressions, repr_op):
return FluxBatchAssign(names=names, expressions=expressions, repr_op=repr_op) return FluxBatchAssign(names=names, expressions=expressions, repr_op=repr_op)
# from CUDA backend:
# def make_flux_batch_assign(self, names, expressions, repr_op):
# from pytools import single_valued
# quadrature_tag = single_valued(
# wdflux.quadrature_tag
# for wdflux in expressions)
# return CUDAFluxBatchAssign(
# names=names, expressions=expressions, repr_op=repr_op,
# dep_mapper_factory=self.dep_mapper_factory,
# quadrature_tag=quadrature_tag)
# }}} # }}}
# {{{ assignment aggregration pass # {{{ assignment aggregration pass
...@@ -1102,6 +1181,29 @@ class OperatorCompiler(IdentityMapper): ...@@ -1102,6 +1181,29 @@ class OperatorCompiler(IdentityMapper):
# }}} # }}}
def internal_map_flux(self, wdflux):
from hedge.optemplate.operators import WholeDomainFluxOperator
return WholeDomainFluxOperator(
wdflux.is_lift,
[wdflux.InteriorInfo(
flux_expr=ii.flux_expr,
field_expr=self.rec(ii.field_expr))
for ii in wdflux.interiors],
[wdflux.BoundaryInfo(
flux_expr=bi.flux_expr,
bpair=self.rec(bi.bpair))
for bi in wdflux.boundaries],
wdflux.quadrature_tag)
def map_whole_domain_flux(self, wdflux):
return self.map_planned_flux(wdflux)
def finalize_multi_assign(self, names, exprs, do_not_return, priority):
return VectorExprAssign(names=names, exprs=exprs,
do_not_return=do_not_return,
dep_mapper_factory=self.dep_mapper_factory,
priority=priority)
# }}} # }}}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment