diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 0bc3d5bc284cb3ae67e744e5cff3f196f9140ec8..1379d726f41105c8d1cf6afa748f938a01d546ae 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -1974,6 +1974,7 @@ class ScopedFunctionCollector(CombineMapper): map_variable = map_constant map_function_symbol = map_constant + map_tagged_variable = map_constant def scope_functions(kernel): @@ -1997,9 +1998,19 @@ def scope_functions(kernel): raise NotImplementedError("scope_functions not implemented for %s" % type(insn)) + scoped_substitutions = {} + + for name, rule in kernel.substitutions.items(): + scoped_rule = rule.copy( + expression=function_scoper(rule.expression)) + scoped_substitutions[name] = scoped_rule + scoped_functions.update(scoped_function_collector(scoped_rule.expression)) + # Need to combine the scoped functions into a dict scoped_function_dict = dict(scoped_functions) - return kernel.copy(instructions=new_insns, scoped_functions=scoped_function_dict) + return kernel.copy(instructions=new_insns, + scoped_functions=scoped_function_dict, + substitutions=scoped_substitutions) # }}} diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py index 9111aebab47e8695499cffb48d01cd9b4e1589d8..852b9ee1d7d1e9691f7d4cc457cf30ddf717c230 100644 --- a/loopy/kernel/function_interface.py +++ b/loopy/kernel/function_interface.py @@ -29,10 +29,13 @@ from six.moves import zip from pytools import ImmutableRecord from loopy.diagnostic import LoopyError -from loopy.kernel.instruction import (MultiAssignmentBase, CInstruction, - _DataObliviousInstruction) +from pymbolic.primitives import Variable +from loopy.symbolic import parse_tagged_name -from loopy.symbolic import IdentityMapper, ScopedFunction + +from loopy.symbolic import (IdentityMapper, ScopedFunction, + SubstitutionRuleMappingContext, RuleAwareIdentityMapper, + SubstitutionRuleExpander) # {{{ argument descriptors @@ -654,49 +657,82 @@ def 
class ScopedFunctionNameChanger(RuleAwareIdentityMapper):
    """
    Mapper that takes in a mapping `expr_to_new_names` and maps the
    corresponding expression to the new names, which correspond to the names
    in `kernel.scoped_functions`.

    An expression is looked up in `expr_to_new_names` twice: first as it
    appears in the kernel, and, failing that, after being passed through
    `subst_expander`.

    :arg rule_mapping_context: forwarded unchanged to
        :class:`RuleAwareIdentityMapper`.
    :arg expr_to_new_names: a mapping from call/reduction expressions to the
        new name of the scoped function they should refer to.
    :arg subst_expander: a callable applied to an expression to obtain the
        form used for the second lookup in `expr_to_new_names`.
    """

    def __init__(self, rule_mapping_context, expr_to_new_names, subst_expander):
        super(ScopedFunctionNameChanger, self).__init__(rule_mapping_context)
        self.expr_to_new_names = expr_to_new_names
        self.subst_expander = subst_expander

    def _new_name_for(self, expr):
        """Return the new scoped-function name registered for *expr*, or
        *None* if neither *expr* nor its expanded form is a key of
        `expr_to_new_names`.
        """
        if expr in self.expr_to_new_names:
            return self.expr_to_new_names[expr]

        # Only expand when the direct lookup missed, so the expander does
        # not run for expressions that are already known.
        expanded_expr = self.subst_expander(expr)
        if expanded_expr in self.expr_to_new_names:
            return self.expr_to_new_names[expanded_expr]

        return None

    def map_call(self, expr, expn_state):
        """Rename the function of a call, or hand the call over to
        substitution-rule handling if it names an old substitution rule.
        """
        if not isinstance(expr.function, Variable):
            # NOTE(review): calls whose function is not a Variable are
            # passed through without any renaming -- confirm this is the
            # intended treatment for already-scoped functions.
            return IdentityMapper.map_call(self, expr, expn_state)

        name, tag = parse_tagged_name(expr.function)

        if name in self.rule_mapping_context.old_subst_rules:
            return self.map_substitution(name, tag, expr.parameters, expn_state)

        new_name = self._new_name_for(expr)
        if new_name is None:
            return IdentityMapper.map_call(self, expr, expn_state)

        return type(expr)(
                ScopedFunction(new_name),
                tuple(self.rec(child, expn_state)
                    for child in expr.parameters))

    def map_call_with_kwargs(self, expr, expn_state):
        """Rename the function of a call that carries keyword parameters."""
        new_name = self._new_name_for(expr)
        if new_name is None:
            return IdentityMapper.map_call_with_kwargs(self, expr, expn_state)

        return type(expr)(
                ScopedFunction(new_name),
                tuple(self.rec(child, expn_state)
                    for child in expr.parameters),
                dict(
                    (key, self.rec(val, expn_state))
                    for key, val in six.iteritems(expr.kw_parameters))
                )

    def map_reduction(self, expr, expn_state):
        """Rename the operation of a reduction, keeping its inames and
        `allow_simultaneous` flag.
        """
        from loopy.symbolic import Reduction

        new_name = self._new_name_for(expr)
        if new_name is None:
            return IdentityMapper.map_reduction(self, expr, expn_state)

        return Reduction(
                ScopedFunction(new_name),
                tuple(expr.inames),
                self.rec(expr.expr, expn_state),
                allow_simultaneous=expr.allow_simultaneous)
- new_insns = [] - scope_changer = ScopedFunctionNameChanger(pymbolic_calls_to_new_names) - for insn in kernel.instructions: - if isinstance(insn, (MultiAssignmentBase, CInstruction)): - expr = scope_changer(insn.expression) - new_insns.append(insn.copy(expression=expr)) - elif isinstance(insn, _DataObliviousInstruction): - new_insns.append(insn) - else: - raise NotImplementedError("Type Inference Specialization not" - "implemented for %s instruciton" % type(insn)) - return kernel.copy(scoped_functions=scoped_names_to_functions, - instructions=new_insns) + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + subst_expander = SubstitutionRuleExpander(kernel.substitutions) + scope_changer = ScopedFunctionNameChanger(rule_mapping_context, + pymbolic_calls_to_new_names, subst_expander) + scoped_kernel = scope_changer.map_kernel(kernel) + + return scoped_kernel.copy(scoped_functions=scoped_names_to_functions) # }}}