diff --git a/pymbolic/interop/maxima.py b/pymbolic/interop/maxima.py index 2dabf2bcc3fdb32373d87c1adea5d22ec5b94b7f..e2c624c5f093188e8103ead233d074e090c46c32 100644 --- a/pymbolic/interop/maxima.py +++ b/pymbolic/interop/maxima.py @@ -40,6 +40,7 @@ import six from six.moves import intern import re import pytools +import numpy as np from pymbolic.mapper.stringifier import StringifyMapper from pymbolic.parser import Parser as ParserBase, FinalizedTuple @@ -92,10 +93,12 @@ class MaximaStringifyMapper(StringifyMapper): class MaximaParser(ParserBase): power_sym = intern("power") imag_unit = intern("imag_unit") + euler_number = intern("euler_number") lex_table = [ (power_sym, pytools.lex.RE(r"\^")), (imag_unit, pytools.lex.RE(r"%i")), + (euler_number, pytools.lex.RE(r"%e")), ] + ParserBase.lex_table def parse_prefix(self, pstate): @@ -126,6 +129,9 @@ class MaximaParser(ParserBase): elif next_tag is self.imag_unit: pstate.advance() return 1j + elif next_tag is self.euler_number: + pstate.advance() + return np.e elif next_tag is p._identifier: if six.PY3: return primitives.Variable(pstate.next_str_and_advance()) @@ -158,7 +164,6 @@ class MaximaParser(ParserBase): pstate.advance() if left_exp == primitives.Variable("matrix"): - import numpy as np left_exp = np.array(list(list(row) for row in args)) else: left_exp = primitives.Call(left_exp, args) @@ -197,7 +202,12 @@ class MaximaParser(ParserBase): did_something = True elif next_tag is self.power_sym and p._PREC_POWER > min_precedence: pstate.advance() - left_exp **= self.parse_expression(pstate, p._PREC_POWER) + exponent = self.parse_expression(pstate, p._PREC_POWER) + if left_exp == np.e: + from pymbolic.primitives import Call, Variable + left_exp = Call(Variable("exp"), (exponent,)) + else: + left_exp **= exponent did_something = True elif next_tag is p._comma and p._PREC_COMMA > min_precedence: # The precedence makes the comma left-associative. diff --git a/pymbolic/interop/sympy.py b/pymbolic/interop/sympy.py index 80fa435d98632766442c2b40be466620c31cfaa9..0d7cd128f3f499cad01d5207c3f2b6c5e57bbdf4 100644 --- a/pymbolic/interop/sympy.py +++ b/pymbolic/interop/sympy.py @@ -29,6 +29,8 @@ from pymbolic.interop.common import ( SympyLikeToPymbolicMapper, PymbolicToSympyLikeMapper) import pymbolic.primitives as prim +from functools import partial + import sympy @@ -67,6 +69,28 @@ class SympyToPymbolicMapper(SympyLikeToPymbolicMapper): tuple(self.rec(i) for i in expr.args[1:]) ) + def map_Piecewise(self, expr): # noqa + # We only handle piecewises with 2 arguments! + assert len(expr.args) == 2 + # We only handle if/else cases + assert expr.args[1][1].is_Boolean and bool(expr.args[1][1]) is True + then = self.rec(expr.args[0][0]) + else_ = self.rec(expr.args[1][0]) + cond = self.rec(expr.args[0][1]) + return prim.If(cond, then, else_) + + def _comparison_operator(self, expr, operator=None): + left = self.rec(expr.args[0]) + right = self.rec(expr.args[1]) + return prim.Comparison(left, operator, right) + + map_Equality = partial(_comparison_operator, operator="==") + map_Unequality = partial(_comparison_operator, operator="!=") + map_GreaterThan = partial(_comparison_operator, operator=">=") + map_LessThan = partial(_comparison_operator, operator="<=") + map_StrictGreaterThan = partial(_comparison_operator, operator=">") + map_StrictLessThan = partial(_comparison_operator, operator="<") + # }}} @@ -84,11 +108,37 @@ class PymbolicToSympyMapper(PymbolicToSympyLikeMapper): return self.sym.Derivative(self.rec(expr.child), *[self.sym.Symbol(v) for v in expr.variables]) +<<<<<<< HEAD def map_subscript(self, expr): return self.sym.tensor.indexed.Indexed( self.rec(expr.aggregate), *tuple(self.rec(i) for i in expr.index_tuple) ) +======= + def map_if(self, expr): + cond = self.rec(expr.condition) + return self.sym.Piecewise((self.rec(expr.then), cond), + (self.rec(expr.else_), True) + ) + + def map_comparison(self, expr): + left = self.rec(expr.left) + right = self.rec(expr.right) + if expr.operator == "==": + return self.sym.Equality(left, right) + elif expr.operator == "!=": + return self.sym.Unequality(left, right) + elif expr.operator == "<": + return self.sym.StrictLessThan(left, right) + elif expr.operator == ">": + return self.sym.StrictGreaterThan(left, right) + elif expr.operator == "<=": + return self.sym.LessThan(left, right) + elif expr.operator == ">=": + return self.sym.GreaterThan(left, right) + else: + raise NotImplementedError("Unknown operator '%s'" % expr.operator) +>>>>>>> c860c0d5a0abcf18b34914b2ccbb0a22c9b98d3f # }}} diff --git a/test/test_maxima.py b/test/test_maxima.py index 7ddbb687dd711b2ff83cadaf8d81c72c0fed8a42..7bd2ad6b640be092d6d3ef3ba21ae5f2bac33948 100644 --- a/test/test_maxima.py +++ b/test/test_maxima.py @@ -74,6 +74,7 @@ def test_strict_round_trip(knl): 2j, parse("x**y"), Quotient(1, 2), + parse("exp(x)") ] for expr in exprs: result = knl.eval_expr(expr) diff --git a/test/test_sympy.py b/test/test_sympy.py index c3b038e4b5ccca7ec30544bc9ae74db01ab557fb..d248fe468e58b57e48b6b3a22837d7d4b99a55a1 100644 --- a/test/test_sympy.py +++ b/test/test_sympy.py @@ -123,6 +123,17 @@ def test_pymbolic_to_sympy(): _test_from_pymbolic(mapper, sym, False) +def test_sympy_if_condition(): + pytest.importorskip("sympy") + from pymbolic.interop.sympy import PymbolicToSympyMapper, SympyToPymbolicMapper + forward = PymbolicToSympyMapper() + backward = SympyToPymbolicMapper() + + # Test round trip to sympy and back + expr = prim.If(prim.Comparison(x_, "<=", y_), 1, 0) + assert backward(forward(expr)) == expr + + if __name__ == "__main__": import sys if len(sys.argv) > 1: