From 8ff29d787095ffc697d406f7e7eb28eb54be5cb9 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 4 Oct 2007 19:10:42 -0400 Subject: [PATCH] Fix a bunch of bugs. (see below) Add a (1-testcase) test suite. Fix leftover from _Name -> name transition in Variable. Change a few pymbolic.product to pymbolic.flattened_product in ExpandMapper. Let ExpandMapper handle constants properly. --- src/mapper/expander.py | 10 +++++++--- src/primitives.py | 2 +- test/test_pymbolic.py | 16 ++++++++++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) create mode 100644 test/test_pymbolic.py diff --git a/src/mapper/expander.py b/src/mapper/expander.py index a93d562..222c036 100644 --- a/src/mapper/expander.py +++ b/src/mapper/expander.py @@ -1,6 +1,8 @@ import pymbolic from pymbolic.mapper import IdentityMapper -from pymbolic.primitives import Sum, Product, Power, AlgebraicLeaf +from pymbolic.primitives import \ + Sum, Product, Power, AlgebraicLeaf, \ + is_constant @@ -40,6 +42,8 @@ class CommutativeTermCollector(object): terms = mul_term.children elif isinstance(mul_term, (Power, AlgebraicLeaf)): terms = [mul_term] + elif is_constant(mul_term): + terms = [mul_term] else: raise RuntimeError, "split_term expects a multiplicative term" @@ -129,13 +133,13 @@ class ExpandMapper(IdentityMapper): from pymbolic.primitives import Expression, Sum if isinstance(expr.base, Product): - return self.rec(pymbolic.product( + return self.rec(pymbolic.flattened_product( child**expr.exponent for child in newbase)) if isinstance(expr.exponent, int): newbase = self.rec(expr.base) if isinstance(newbase, Sum): - return self.map_product(pymbolic.product(expr.exponent*(newbase,))) + return self.map_product(pymbolic.flattened_product(expr.exponent*(newbase,))) else: return IdentitityMapper.map_power(expr) else: diff --git a/src/primitives.py b/src/primitives.py index 6797572..10cbbe2 100644 --- a/src/primitives.py +++ b/src/primitives.py @@ -175,7 +175,7 @@ class Variable(Leaf): self.name = name def __getinitargs__(self): - return self._Name, + return self.name, def __lt__(self, other): if isinstance(other, Variable): diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py new file mode 100644 index 0000000..4b14c56 --- /dev/null +++ b/test/test_pymbolic.py @@ -0,0 +1,16 @@ +import unittest + + + + +class TestPymbolic(unittest.TestCase): + def test_expand(self): + from pymbolic import var, expand + + x = var("x") + u = (x+1)**5 + print expand(u) + + +if __name__ == '__main__': + unittest.main() -- GitLab