diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 56cd896eb2ba05fc6e264387e7553f7764bc05c9..63a564eaf43f01648ae4d21ab30ee54cc2cb2a68 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -625,6 +625,8 @@ class ParametrizedSubstitutor(IdentityMapper): raise RuntimeError("CSE '%s' must be invoked with %d arguments" % (expr.name, len(arg_names))) + cse_expr = self.rec(cse_expr) + if self.wrap_cse: from pymbolic.primitives import CommonSubexpression return CommonSubexpression(cse_expr, expr.name) @@ -647,7 +649,7 @@ class ParametrizedSubstitutor(IdentityMapper): subst_map = SubstitutionMapper(make_subst_func( dict(zip(arg_names, expr.parameters)))) - cse_expr = subst_map(cse_expr) + cse_expr = self.rec(subst_map(cse_expr)) if self.wrap_cse: return CommonSubexpression(cse_expr, cse_name)