diff --git a/src/mapper/collector.py b/src/mapper/collector.py new file mode 100644 index 0000000000000000000000000000000000000000..2dfc8d554846db139523b8b4e7798e47945899eb --- /dev/null +++ b/src/mapper/collector.py @@ -0,0 +1,80 @@ +import pymbolic +from pymbolic.mapper import IdentityMapper + + + + +class TermCollector(IdentityMapper): + """A term collector that assumes that multiplication is commutative. + """ + + def __init__(self, parameters=set()): + self.parameters = parameters + + def split_term(self, mul_term): + """Returns a pair consisting of: + - a frozenset of (base, exponent) pairs + - a product of coefficients (i.e. constants and parameters) + + The set takes care of order-invariant comparison for us and is hashable. + + The argument `product' has to be fully expanded already. + """ + from pymbolic import get_dependencies, is_constant + from pymbolic.primitives import Product, Power, AlgebraicLeaf + + def base(term): + if isinstance(term, Power): + return term.base + else: + return term + + def exponent(term): + if isinstance(term, Power): + return term.exponent + else: + return 1 + + if isinstance(mul_term, Product): + terms = mul_term.children + elif isinstance(mul_term, (Power, AlgebraicLeaf)): + terms = [mul_term] + elif is_constant(mul_term): + terms = [mul_term] + else: + raise RuntimeError, "split_term expects a multiplicative term" + + base2exp = {} + for term in terms: + mybase = base(term) + myexp = exponent(term) + + if mybase in base2exp: + base2exp[mybase] += myexp + else: + base2exp[mybase] = myexp + + coefficients = [] + cleaned_base2exp = {} + for base, exp in base2exp.iteritems(): + term = base**exp + if get_dependencies(term) <= self.parameters: + coefficients.append(term) + else: + cleaned_base2exp[base] = exp + + term = frozenset((base,exp) for base, exp in cleaned_base2exp.iteritems()) + return term, pymbolic.flattened_product(coefficients) + + def map_sum(self, mysum): + term2coeff = {} + for child in mysum.children: + term, coeff = self.split_term(child) + term2coeff[term] = term2coeff.get(term, 0) + coeff + + def rep2term(rep): + return pymbolic.flattened_product(base**exp for base, exp in rep) + + result = pymbolic.flattened_sum(coeff*rep2term(termrep) + for termrep, coeff in term2coeff.iteritems()) + return result diff --git a/src/mapper/expander.py b/src/mapper/expander.py index 222c0368d8e0e6e6347b4c56b9f1806eec451474..6978bc76b20765c63447bbcd2f8a9eef539fe2d9 100644 --- a/src/mapper/expander.py +++ b/src/mapper/expander.py @@ -1,5 +1,6 @@ import pymbolic from pymbolic.mapper import IdentityMapper +from pymbolic.mapper.collector import TermCollector from pymbolic.primitives import \ Sum, Product, Power, AlgebraicLeaf, \ is_constant @@ -7,89 +8,9 @@ from pymbolic.primitives import \ -class CommutativeTermCollector(object): - """A term collector that assumes that multiplication is commutative. - """ - - def __init__(self, parameters=set()): - self.Parameters = parameters - - def split_term(self, mul_term): - """Returns a pair consisting of: - - a frozenset of (base, exponent) pairs - - a product of coefficients (i.e. constants and parameters) - - The set takes care of order-invariant comparison for us and is hashable. - - The argument `product' has to be fully expanded already. - """ - from pymbolic import get_dependencies - - - def base(term): - if isinstance(term, Power): - return term.base - else: - return term - - def exponent(term): - if isinstance(term, Power): - return term.exponent - else: - return 1 - - if isinstance(mul_term, Product): - terms = mul_term.children - elif isinstance(mul_term, (Power, AlgebraicLeaf)): - terms = [mul_term] - elif is_constant(mul_term): - terms = [mul_term] - else: - raise RuntimeError, "split_term expects a multiplicative term" - - base2exp = {} - for term in terms: - mybase = base(term) - myexp = exponent(term) - - if mybase in base2exp: - base2exp[mybase] += myexp - else: - base2exp[mybase] = myexp - - coefficients = [] - cleaned_base2exp = {} - for base, exp in base2exp.iteritems(): - term = base**exp - if get_dependencies(term) <= self.Parameters: - coefficients.append(term) - else: - cleaned_base2exp[base] = exp - - term = frozenset((base,exp) for base, exp in cleaned_base2exp.iteritems()) - return term, pymbolic.flattened_product(coefficients) - - def __call__(self, mysum): - assert isinstance(mysum, Sum) - - term2coeff = {} - for child in mysum.children: - term, coeff = self.split_term(child) - term2coeff[term] = term2coeff.get(term, 0) + coeff - - def rep2term(rep): - return pymbolic.flattened_product(base**exp for base, exp in rep) - - result = pymbolic.flattened_sum(coeff*rep2term(termrep) - for termrep, coeff in term2coeff.iteritems()) - return result - - - - class ExpandMapper(IdentityMapper): - def __init__(self, collector=CommutativeTermCollector()): - self.Collector = collector + def __init__(self, collector=TermCollector()): + self.collector = collector def map_product(self, expr): from pymbolic.primitives import Sum, Product @@ -118,7 +39,7 @@ class ExpandMapper(IdentityMapper): else: rest = 1 - result = self.Collector(pymbolic.flattened_sum( + result = self.collector(pymbolic.flattened_sum( pymbolic.flattened_product(leading) * expand(sumchild*rest) for sumchild in sum.children )) @@ -150,6 +71,6 @@ class ExpandMapper(IdentityMapper): def expand(expr, parameters=set(), commutative=True): if commutative: - return ExpandMapper(CommutativeTermCollector(parameters))(expr) + return ExpandMapper(TermCollector(parameters))(expr) else: return ExpandMapper(lambda x: x)(expr)