diff --git a/src/mapper/expander.py b/src/mapper/expander.py index a93d562db93d119ad07fec448266b0323262efed..222c0368d8e0e6e6347b4c56b9f1806eec451474 100644 --- a/src/mapper/expander.py +++ b/src/mapper/expander.py @@ -1,6 +1,8 @@ import pymbolic from pymbolic.mapper import IdentityMapper -from pymbolic.primitives import Sum, Product, Power, AlgebraicLeaf +from pymbolic.primitives import \ + Sum, Product, Power, AlgebraicLeaf, \ + is_constant @@ -40,6 +42,8 @@ class CommutativeTermCollector(object): terms = mul_term.children elif isinstance(mul_term, (Power, AlgebraicLeaf)): terms = [mul_term] + elif is_constant(mul_term): + terms = [mul_term] else: raise RuntimeError, "split_term expects a multiplicative term" @@ -129,13 +133,13 @@ class ExpandMapper(IdentityMapper): from pymbolic.primitives import Expression, Sum if isinstance(expr.base, Product): - return self.rec(pymbolic.product( + return self.rec(pymbolic.flattened_product( child**expr.exponent for child in newbase)) if isinstance(expr.exponent, int): newbase = self.rec(expr.base) if isinstance(newbase, Sum): - return self.map_product(pymbolic.product(expr.exponent*(newbase,))) + return self.map_product(pymbolic.flattened_product(expr.exponent*(newbase,))) else: return IdentitityMapper.map_power(expr) else: diff --git a/src/primitives.py b/src/primitives.py index 6797572c6ef4f67c340211ed8bdee22a180201f8..10cbbe28b1973e993d73d1b6b410ac41641f6b46 100644 --- a/src/primitives.py +++ b/src/primitives.py @@ -175,7 +175,7 @@ class Variable(Leaf): self.name = name def __getinitargs__(self): - return self._Name, + return self.name, def __lt__(self, other): if isinstance(other, Variable): diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py new file mode 100644 index 0000000000000000000000000000000000000000..4b14c56e3012f0dcf471ec3acc88ab2f88479335 --- /dev/null +++ b/test/test_pymbolic.py @@ -0,0 +1,16 @@ +import unittest + + + + +class TestPymbolic(unittest.TestCase): + def test_expand(self): + from pymbolic import var, expand + + x = var("x") + u = (x+1)**5 + print expand(u) + + +if __name__ == '__main__': + unittest.main()