diff --git a/src/mapper/collector.py b/src/mapper/collector.py index 2dfc8d554846db139523b8b4e7798e47945899eb..da0fbd594785dc5a144230b566158d008b58e560 100644 --- a/src/mapper/collector.py +++ b/src/mapper/collector.py @@ -11,6 +11,10 @@ class TermCollector(IdentityMapper): def __init__(self, parameters=set()): self.parameters = parameters + def get_dependencies(self, expr): + from pymbolic.mapper.dependency import DependencyMapper + return DependencyMapper()(expr) + def split_term(self, mul_term): """Returns a pair consisting of: - a frozenset of (base, exponent) pairs @@ -20,7 +24,6 @@ class TermCollector(IdentityMapper): The argument `product' has to be fully expanded already. """ - from pymbolic import get_dependencies, is_constant from pymbolic.primitives import Product, Power, AlgebraicLeaf def base(term): @@ -39,7 +42,7 @@ class TermCollector(IdentityMapper): terms = mul_term.children elif isinstance(mul_term, (Power, AlgebraicLeaf)): terms = [mul_term] - elif is_constant(mul_term): + elif not bool(self.get_dependencies(mul_term)): terms = [mul_term] else: raise RuntimeError, "split_term expects a multiplicative term" @@ -58,7 +61,7 @@ class TermCollector(IdentityMapper): cleaned_base2exp = {} for base, exp in base2exp.iteritems(): term = base**exp - if get_dependencies(term) <= self.parameters: + if self.get_dependencies(term) <= self.parameters: coefficients.append(term) else: cleaned_base2exp[base] = exp