diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 60473ada580fdca7b618f8e33086e0d7ea9f267a..f5cf07b0e1d62212ce36edb48f47eb7de7d31451 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -866,7 +866,8 @@ class RuleAwareIdentityMapper(IdentityMapper): if name not in self.rule_mapping_context.old_subst_rules: return super(RuleAwareIdentityMapper, self).map_call(expr, expn_state) else: - return self.map_substitution(name, tag, expr.parameters, expn_state) + return self.map_substitution(name, tag, self.rec( + expr.parameters, expn_state), expn_state) @staticmethod def make_new_arg_context(rule_name, arg_names, arguments, arg_context): diff --git a/loopy/transform/subst.py b/loopy/transform/subst.py index a681afe06520483c83530c241e39229412e88f03..b92698ffa1e84455be3f79bed7dbf884f36be490 100644 --- a/loopy/transform/subst.py +++ b/loopy/transform/subst.py @@ -469,6 +469,13 @@ def assignment_to_subst(kernel, lhs_name, extra_arguments=(), within=None, # {{{ expand_subst def expand_subst(kernel, within=None): + """ + Returns an instance of :class:`loopy.LoopKernel` with the substitutions + referenced in instructions of *kernel* matched by *within* expanded. + + :arg within: a stack match as understood by + :func:`loopy.match.parse_stack_match`. + """ if not kernel.substitutions: return kernel diff --git a/test/test_transform.py b/test/test_transform.py index 394cf668804ed719920e02bc3d20f62971421c2f..3ee67b703964d1f7773b10a9199687d78b883a60 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -550,6 +550,26 @@ def test_split_iname_only_if_in_within(): assert insn.within_inames == frozenset({'i'}) +def test_nested_substs_in_insns(ctx_factory): + ctx = ctx_factory() + import loopy as lp + + ref_knl = lp.make_kernel( + "{[i]: 0<=i<10}", + """ + a(x) := 2 * x + b(x) := x**2 + c(x) := 7 * x + f[i] = c(b(a(i))) + """ + ) + + knl = lp.expand_subst(ref_knl) + assert not knl.substitutions + + lp.auto_test_vs_ref(ref_knl, ctx, knl) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])