diff --git a/grudge/symbolic/compiler.py b/grudge/symbolic/compiler.py index 972e4056729808778e5875ba786ef25c1348f362..53db42f18665fa37582c6e92afc5b214916f8e66 100644 --- a/grudge/symbolic/compiler.py +++ b/grudge/symbolic/compiler.py @@ -823,26 +823,16 @@ def aggregate_assignments(inf_mapper, instructions, result, # {{{ to-loopy mapper -def set_once(d, k, v): - try: - v_prev = d[k] - except KeyError: - d[k] = v - else: - assert v_prev == d[k] - - class ToLoopyExpressionMapper(mappers.IdentityMapper): - def __init__(self, dd_inference_mapper, output_names, temp_names, iname): + def __init__(self, dd_inference_mapper, temp_names, iname): self.dd_inference_mapper = dd_inference_mapper - self.output_names = output_names self.temp_names = temp_names self.iname = iname from pymbolic import var self.iname_expr = var(iname) - self.input_mappings = {} - self.output_mappings = {} + self.expr_to_name = {} + self.used_names = set() self.non_scalar_vars = [] def map_name(self, name): @@ -852,31 +842,48 @@ class ToLoopyExpressionMapper(mappers.IdentityMapper): else: return name - def map_variable_reference(self, name, expr): + def map_variable_ref_expr(self, expr, name_prefix): from pymbolic import var dd = self.dd_inference_mapper(expr) - mapped_name = self.map_name(name) - if name in self.output_names: - set_once(self.output_mappings, name, expr) - else: - set_once(self.input_mappings, mapped_name, expr) + try: + name = self.expr_to_name[expr] + except KeyError: + name_prefix = self.map_name(name_prefix) + name = name_prefix + + suffix_nr = 0 + while name in self.used_names: + name = "%s_%s" % (name_prefix, suffix_nr) + suffix_nr += 1 + self.used_names.add(name) + + self.expr_to_name[expr] = name from grudge.symbolic.primitives import DTAG_SCALAR if dd.domain_tag == DTAG_SCALAR or name in self.temp_names: - return var(mapped_name) + return var(name) else: self.non_scalar_vars.append(name) - return var(mapped_name)[self.iname_expr] + return var(name)[self.iname_expr] def map_variable(self, expr): - return self.map_variable_reference(expr.name, expr) + return self.map_variable_ref_expr(expr, expr.name) def map_grudge_variable(self, expr): - return self.map_variable_reference(expr.name, expr) + return self.map_variable_ref_expr(expr, expr.name) def map_subscript(self, expr): - return self.map_variable_reference(expr.aggregate.name, expr) + subscript = expr.index + if isinstance(subscript, tuple): + assert len(subscript) == 1 + subscript, = subscript + + assert isinstance(subscript, int) + + return self.map_variable_ref_expr( + expr, + "%s_%d" % (expr.aggregate.name, subscript)) def map_call(self, expr): if isinstance(expr.function, sym.CFunction): @@ -896,11 +903,8 @@ class ToLoopyExpressionMapper(mappers.IdentityMapper): return 1 def map_node_coordinate_component(self, expr): - mapped_name = "grdg_ncc%d" % expr.axis - set_once(self.input_mappings, mapped_name, expr) - - from pymbolic import var - return var(mapped_name)[self.iname_expr] + return self.map_variable_ref_expr( + expr, "grdg_ncc%d" % expr.axis) def map_common_subexpression(self, expr): raise ValueError("not expecting CSEs at this stage in the " @@ -978,7 +982,7 @@ class ToLoopyInstructionMapper(object): if dnr] expr_mapper = ToLoopyExpressionMapper( - self.dd_inference_mapper, insn.names, temp_names, iname) + self.dd_inference_mapper, temp_names, iname) insns = [] import loopy as lp @@ -1019,11 +1023,33 @@ class ToLoopyInstructionMapper(object): knl = lp.register_function_manglers(knl, [bessel_function_mangler]) + input_mappings = {} + output_mappings = {} + + from grudge.symbolic.mappers import DependencyMapper + dep_mapper = DependencyMapper(composite_leaves=False) + + for expr, name in six.iteritems(expr_mapper.expr_to_name): + deps = dep_mapper(expr) + assert len(deps) <= 1 + if not deps: + is_output = False + else: + dep, = deps + is_output = dep.name in insn.names + + if is_output: + tgt_dict = output_mappings + else: + tgt_dict = input_mappings + + tgt_dict[name] = expr + return LoopyKernelInstruction( LoopyKernelDescriptor( loopy_kernel=knl, - input_mappings=expr_mapper.input_mappings, - output_mappings=expr_mapper.output_mappings, + input_mappings=input_mappings, + output_mappings=output_mappings, fixed_arguments={}, governing_dd=governing_dd) )