diff --git a/loopy/subst.py b/loopy/subst.py index 412623c90ee1bc981b0d56a1738a54fe93005730..3b112a4fd44880f6cd66019aff1023ee426ae125 100644 --- a/loopy/subst.py +++ b/loopy/subst.py @@ -202,7 +202,7 @@ def extract_subst(kernel, subst_name, template, parameters=()): class TemporaryToSubstChanger(RuleAwareIdentityMapper): def __init__(self, rule_mapping_context, temp_name, definition_insn_ids, - usage_to_definition, within): + usage_to_definition, extra_arguments, within): self.var_name_gen = rule_mapping_context.make_unique_var_name super(TemporaryToSubstChanger, self).__init__(rule_mapping_context) @@ -211,6 +211,9 @@ class TemporaryToSubstChanger(RuleAwareIdentityMapper): self.definition_insn_ids = definition_insn_ids self.usage_to_definition = usage_to_definition + from pymbolic import var + self.extra_arguments = tuple(var(arg) for arg in extra_arguments) + self.within = within self.definition_insn_id_to_subst_name = {} @@ -248,7 +251,7 @@ class TemporaryToSubstChanger(RuleAwareIdentityMapper): expr, expn_state) def transform_access(self, index, expn_state): - my_insn_id = expn_state.stack[0][0] + my_insn_id = expn_state.insn_id if my_insn_id in self.definition_insn_ids: return None @@ -259,10 +262,14 @@ class TemporaryToSubstChanger(RuleAwareIdentityMapper): self.saw_unmatched_usage_sites[my_def_id] = True return None - my_insn_id = expn_state.stack[0][0] - subst_name = self.get_subst_name(my_def_id) + if self.extra_arguments: + if index is None: + index = self.extra_arguments + else: + index = index + self.extra_arguments + from pymbolic import var if index is None: return var(subst_name) @@ -270,9 +277,11 @@ class TemporaryToSubstChanger(RuleAwareIdentityMapper): return var(subst_name)(*index) -def temporary_to_subst(kernel, temp_name, within=None): +def temporary_to_subst(kernel, temp_name, extra_arguments=(), within=None): """Extract an assignment to a temporary variable - as a :ref:`substituion-rule`. The temporary may + as a :ref:`substitution-rule`. 
The temporary may be an array, in + which case the array indices will become arguments to the substitution + rule. :arg within: a stack match as understood by :func:`loopy.context_matching.parse_stack_match`. @@ -284,6 +293,9 @@ def temporary_to_subst(kernel, temp_name, within=None): as the temporary variable is left in place. """ + if isinstance(extra_arguments, str): + extra_arguments = tuple(s.strip() for s in extra_arguments.split(",")) + # {{{ establish the relevant definition of temp_name for each usage site dep_kernel = expand_subst(kernel) @@ -351,7 +363,7 @@ def temporary_to_subst(kernel, temp_name, within=None): kernel.substitutions, kernel.get_var_name_generator()) tts = TemporaryToSubstChanger(rule_mapping_context, temp_name, definition_insn_ids, - usage_to_definition, within) + usage_to_definition, extra_arguments, within) kernel = rule_mapping_context.finish_kernel(tts.map_kernel(kernel)) @@ -379,7 +391,7 @@ def temporary_to_subst(kernel, temp_name, within=None): new_substs[subst_name] = SubstitutionRule( name=subst_name, - arguments=tuple(arguments), + arguments=tuple(arguments) + extra_arguments, expression=def_insn.expression) # }}} diff --git a/test/test_fortran.py b/test/test_fortran.py index 4e8de305a927f3c16c0902324491549915fc0ed9..d51411a6e65ba750ddf8cfc2ee844c4bc167ab82 100644 --- a/test/test_fortran.py +++ b/test/test_fortran.py @@ -130,7 +130,7 @@ def test_temporary_to_subst(ctx_factory): integer n do i = 1, n - a = inp(n) + a = inp(i) out(i) = 5*a out2(i) = 6*a end do @@ -142,7 +142,7 @@ def test_temporary_to_subst(ctx_factory): ref_knl = knl - knl = lp.temporary_to_subst(knl, "a") + knl = lp.temporary_to_subst(knl, "a", "i") ctx = ctx_factory() lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5))