diff --git a/pymbolic/interop/common.py b/pymbolic/interop/common.py index dd9177537bc95d09615c04022129c1d9066f908b..7501a5078f3e7c52badcef4ba36c543275426c0d 100644 --- a/pymbolic/interop/common.py +++ b/pymbolic/interop/common.py @@ -111,7 +111,7 @@ class SympyLikeToPymbolicMapper(SympyLikeMapperBase): return prim.CommonSubexpression( self.rec(expr.args[0]), expr.prefix) - def not_supported(self, expr): + def not_supported(self, expr): # noqa if isinstance(expr, int): return expr elif getattr(expr, "is_Function", False): diff --git a/pymbolic/interop/sympy.py b/pymbolic/interop/sympy.py index b4807712591e695a434662d4fc93b88f3322575d..1e31eade77fdd4a0faa0730319d312031f3cd57d 100644 --- a/pymbolic/interop/sympy.py +++ b/pymbolic/interop/sympy.py @@ -63,6 +63,12 @@ class SympyToPymbolicMapper(SympyLikeToPymbolicMapper): def map_long(self, expr): return long(expr) # noqa + def map_Indexed(self, expr): # noqa + return prim.Subscript( + self.rec(expr.args[0].args[0]), + tuple(self.rec(i) for i in expr.args[1:]) + ) + def map_Piecewise(self, expr): # noqa # We only handle piecewises with 2 arguments! assert len(expr.args) == 2 @@ -85,7 +91,6 @@ class SympyToPymbolicMapper(SympyLikeToPymbolicMapper): map_StrictGreaterThan = partial(_comparison_operator, operator=">") map_StrictLessThan = partial(_comparison_operator, operator="<") - # }}} @@ -103,6 +108,12 @@ class PymbolicToSympyMapper(PymbolicToSympyLikeMapper): return self.sym.Derivative(self.rec(expr.child), *[self.sym.Symbol(v) for v in expr.variables]) + def map_subscript(self, expr): + return self.sym.tensor.indexed.Indexed( + self.rec(expr.aggregate), + *tuple(self.rec(i) for i in expr.index_tuple) + ) + def map_if(self, expr): cond = self.rec(expr.condition) return self.sym.Piecewise((self.rec(expr.then), cond), diff --git a/test/test_sympy.py b/test/test_sympy.py index e0429bb3eaf327725a1846b3450f4a41391530c7..d248fe468e58b57e48b6b3a22837d7d4b99a55a1 100644 --- a/test/test_sympy.py +++ b/test/test_sympy.py @@ -25,7 +25,7 @@ THE SOFTWARE. import pytest import pymbolic.primitives as prim -x_, y_ = (prim.Variable(s) for s in "x y".split()) +x_, y_, i_, j_ = (prim.Variable(s) for s in "x y i j".split()) # {{{ to pymbolic test @@ -51,6 +51,11 @@ def _test_to_pymbolic(mapper, sym, use_symengine): assert mapper(sym.Function("f")(x)) == prim.Variable("f")(x_) assert mapper(sym.exp(x)) == prim.Variable("exp")(x_) + # indexed accesses + if not use_symengine: + i, j = sym.symbols("i,j") + assert mapper(sym.tensor.indexed.Indexed(x, i, j)) == x_[i_, j_] + # constants import math # FIXME: Why isn't this exact? @@ -91,7 +96,11 @@ def _test_from_pymbolic(mapper, sym, use_symengine): deriv = sym.Derivative(x**2, x) assert mapper(prim.Derivative(x_**2, ("x",))) == deriv - assert mapper(x_[0]) == sym.Symbol("x_0") + if use_symengine: + assert mapper(x_[0]) == sym.Symbol("x_0") + else: + i, j = sym.symbols("i,j") + assert mapper(x_[i_, j_]) == sym.tensor.indexed.Indexed(x, i, j) assert mapper(prim.Variable("f")(x_)) == sym.Function("f")(x)