diff --git a/pymbolic/cse.py b/pymbolic/cse.py index 7ebaae61ef7735d3d029474c09f0f3e672d60581..700479a06ee18990e2fb0ed70cd644b2a0c1e799 100644 --- a/pymbolic/cse.py +++ b/pymbolic/cse.py @@ -69,7 +69,10 @@ class CSEMapper(IdentityMapper): def map_common_subexpression(self, expr): # don't duplicate CSEs - return prim.wrap_in_cse(self.rec(expr.child), expr.prefix) + if type(expr) is prim.CommonSubexpression: + return prim.wrap_in_cse(self.rec(expr.child), expr.prefix) + else: + return IdentityMapper.map_common_subexpression(self, expr) def map_substitution(self, expr): return type(expr)(