diff --git a/pymbolic/cse.py b/pymbolic/cse.py index b7c6109088db2e81f523a2f8f87467211cace539..34fe275102163ea7f0f888cfeb964f536fb3434b 100644 --- a/pymbolic/cse.py +++ b/pymbolic/cse.py @@ -110,14 +110,15 @@ class CSEMapper(IdentityMapper): def map_common_subexpression(self, expr): # Avoid creating CSE(CSE(...)) - # - # NOTE: This is not equivalent to isinstance--it's more specific, - # and for a reason! (Because we don't want to obliterate derived - # CSE types.) if type(expr) is prim.CommonSubexpression: return prim.wrap_in_cse(self.rec(expr.child), expr.prefix) else: - return IdentityMapper.map_common_subexpression(self, expr) + # expr is of a derived CSE type + result = self.rec(expr.child) + if type(result) is prim.CommonSubexpression: + result = result.child + + return type(expr)(result, expr.prefix, **expr.get_extra_properties()) def map_substitution(self, expr): return type(expr)(