diff --git a/pymbolic/sympy_conv.py b/pymbolic/sympy_conv.py index 049131c9da35e81f9250c60419cfe9d264f24c6e..079a14acde9d80a88d96019d0484b7f7fddad9ee 100644 --- a/pymbolic/sympy_conv.py +++ b/pymbolic/sympy_conv.py @@ -34,13 +34,32 @@ class _SympyMapper(object): -class ToPymbolicMapper(_SympyMapper): +class CSE(sp.Function): + """A function to translate to a Pymbolic CSE.""" + + nargs = 1 + + + + +def make_cse(arg, prefix=None): + result = CSE(arg) + result.prefix = prefix + return result + + + + +class SympyToPymbolicMapper(_SympyMapper): def map_Symbol(self, expr): return prim.Variable(expr.name) def map_ImaginaryUnit(self, expr): return 1j + def map_Pi(self, expr): + return float(expr) + def map_Add(self, expr): return prim.Sum(tuple(self.rec(arg) for arg in expr.args)) @@ -68,6 +87,10 @@ class ToPymbolicMapper(_SympyMapper): return prim.Derivative(self.rec(expr.expr), tuple(v.name for v in expr.variables)) + def map_CSE(self, expr): + return prim.CommonSubexpression( + self.rec(expr.args[0]), expr.prefix) + def not_supported(self, expr): if isinstance(expr, int): return expr