From 8e955fed5ad64896c7afe2d136bc01cf5f3e0801 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 9 Dec 2008 02:54:33 -0500 Subject: [PATCH] Extract TermCollector's dep on a DependencyMapper into an overridable method. --- src/mapper/collector.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/mapper/collector.py b/src/mapper/collector.py index 2dfc8d5..da0fbd5 100644 --- a/src/mapper/collector.py +++ b/src/mapper/collector.py @@ -11,6 +11,10 @@ class TermCollector(IdentityMapper): def __init__(self, parameters=set()): self.parameters = parameters + def get_dependencies(self, expr): + from pymbolic.mapper.dependency import DependencyMapper + return DependencyMapper()(expr) + def split_term(self, mul_term): """Returns a pair consisting of: - a frozenset of (base, exponent) pairs @@ -20,7 +24,6 @@ class TermCollector(IdentityMapper): The argument `product' has to be fully expanded already. """ - from pymbolic import get_dependencies, is_constant from pymbolic.primitives import Product, Power, AlgebraicLeaf def base(term): @@ -39,7 +42,7 @@ class TermCollector(IdentityMapper): terms = mul_term.children elif isinstance(mul_term, (Power, AlgebraicLeaf)): terms = [mul_term] - elif is_constant(mul_term): + elif not bool(self.get_dependencies(mul_term)): terms = [mul_term] else: raise RuntimeError, "split_term expects a multiplicative term" @@ -58,7 +61,7 @@ class TermCollector(IdentityMapper): cleaned_base2exp = {} for base, exp in base2exp.iteritems(): term = base**exp - if get_dependencies(term) <= self.parameters: + if self.get_dependencies(term) <= self.parameters: coefficients.append(term) else: cleaned_base2exp[base] = exp -- GitLab