From fe6ac8417e876d2640430afc6113dfb7f8968fb1 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Mon, 8 Dec 2008 21:01:04 -0500 Subject: [PATCH] Amputate lots of handle_unsupported_expression() routines. --- src/__init__.py | 2 -- src/compiler.py | 5 +++-- src/mapper/collector.py | 8 -------- src/mapper/constant_folder.py | 11 ++++++----- src/mapper/dependency.py | 24 ------------------------ src/mapper/evaluator.py | 8 +------- src/mapper/expander.py | 2 -- src/mapper/flattener.py | 2 -- src/mapper/flop_counter.py | 3 --- 9 files changed, 10 insertions(+), 55 deletions(-) diff --git a/src/__init__.py b/src/__init__.py index a7c69c0..de5ad50 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -23,8 +23,6 @@ parse = pymbolic.parser.parse evaluate = pymbolic.mapper.evaluator.evaluate evaluate_kw = pymbolic.mapper.evaluator.evaluate_kw compile = pymbolic.compiler.compile -is_constant = pymbolic.mapper.dependency.is_constant -get_dependencies = pymbolic.mapper.dependency.get_dependencies substitute = pymbolic.mapper.substitutor.substitute differentiate = pymbolic.mapper.differentiator.differentiate expand = pymbolic.mapper.expander.expand diff --git a/src/compiler.py b/src/compiler.py index 90974f7..db5fec3 100644 --- a/src/compiler.py +++ b/src/compiler.py @@ -60,8 +60,9 @@ class CompiledExpression: def __compile__(self): ctx = self.context() - used_variables = pymbolic.get_dependencies(self._Expression, - composite_leaves=False) + from pymbolic.mapper.dependency import DependencyMapper + used_variables = DependencyMapper( + composite_leaves=False)(self._Expression) used_variables -= set(self._Variables) used_variables -= set(pymbolic.var(key) for key in ctx.keys()) used_variables = list(used_variables) diff --git a/src/mapper/collector.py b/src/mapper/collector.py index bc1bc53..2dfc8d5 100644 --- a/src/mapper/collector.py +++ b/src/mapper/collector.py @@ -78,11 +78,3 @@ class TermCollector(IdentityMapper): result = pymbolic.flattened_sum(coeff*rep2term(termrep) for termrep, coeff in term2coeff.iteritems()) return result - - def handle_unsupported_expression(self, expr): - from pymbolic.primitives import AlgebraicLeaf - if isinstance(expr, AlgebraicLeaf): - return expr - else: - IdentityMapper.handle_unsupported_expression(self, expr) - diff --git a/src/mapper/constant_folder.py b/src/mapper/constant_folder.py index 5da8218..c9fa198 100644 --- a/src/mapper/constant_folder.py +++ b/src/mapper/constant_folder.py @@ -4,8 +4,12 @@ from pymbolic.mapper import IdentityMapper, NonrecursiveIdentityMapper class ConstantFoldingMapperBase(object): + def is_constant(self, expr): + from pymbolic.mapper.dependency import DependencyMapper + return not bool(DependencyMapper()(expr)) + def fold(self, expr, klass, op, constructor): - from pymbolic import is_constant, evaluate + from pymbolic import evaluate constants = [] nonconstants = [] @@ -16,7 +20,7 @@ class ConstantFoldingMapperBase(object): if isinstance(child, klass): queue = list(child.children) + queue else: - if is_constant(child): + if self.is_constant(child): constants.append(evaluate(child)) else: nonconstants.append(child) @@ -34,9 +38,6 @@ class ConstantFoldingMapperBase(object): return self.fold(expr, Sum, operator.add, Sum) - def handle_unsupported_expression(self, expr): - return expr - class CommutativeConstantFoldingMapperBase(ConstantFoldingMapperBase): diff --git a/src/mapper/dependency.py b/src/mapper/dependency.py index aea10c4..396be4f 100644 --- a/src/mapper/dependency.py +++ b/src/mapper/dependency.py @@ -30,13 +30,6 @@ class DependencyMapper(CombineMapper): import operator return reduce(operator.or_, values, set()) - def handle_unsupported_expression(self, expr): - from pymbolic.primitives import AlgebraicLeaf - if isinstance(expr, AlgebraicLeaf): - return set([expr]) - else: - CombineMapper.handle_unsupported_expression(self, expr) - def map_constant(self, expr): return set() @@ -69,20 +62,3 @@ class DependencyMapper(CombineMapper): return set([expr]) else: return CombineMapper.map_common_subexpression(self, expr) - - - - - -def get_dependencies(expr, **kwargs): - return DependencyMapper(**kwargs)(expr) - - - - -def is_constant(expr, with_respect_to=None, **kwargs): - if not with_respect_to: - return not bool(get_dependencies(expr, **kwargs)) - else: - return not (set(with_respect_to) & get_dependencies(expr, **kwargs)) - diff --git a/src/mapper/evaluator.py b/src/mapper/evaluator.py index 98799c9..82900b8 100644 --- a/src/mapper/evaluator.py +++ b/src/mapper/evaluator.py @@ -72,7 +72,7 @@ class EvaluationMapper(RecursiveMapper): result[i] = self.rec(expr[i]) return result - def map_common_subexpression(self, expr, out=None): + def map_common_subexpression(self, expr): try: return self.common_subexp_cache[expr.child] except KeyError: @@ -83,12 +83,6 @@ class EvaluationMapper(RecursiveMapper): class FloatEvaluationMapper(EvaluationMapper): - def handle_unsupported_expression(self, expr): - try: - return float(expr) - except: - raise TypeError, "cannot convert %s to float" % type(expr) - def map_constant(self, expr): return float(expr) diff --git a/src/mapper/expander.py b/src/mapper/expander.py index 0822075..6978bc7 100644 --- a/src/mapper/expander.py +++ b/src/mapper/expander.py @@ -66,8 +66,6 @@ class ExpandMapper(IdentityMapper): else: return IdentitityMapper.map_power(expr) - def handle_unsupported_expression(self, expr): - return expr diff --git a/src/mapper/flattener.py b/src/mapper/flattener.py index ee05303..102ca90 100644 --- a/src/mapper/flattener.py +++ b/src/mapper/flattener.py @@ -12,8 +12,6 @@ class FlattenMapper(IdentityMapper): from pymbolic.primitives import flattened_product return flattened_product(self.rec(ch) for ch in expr.children) - def handle_unsupported_expression(self, expr): - return expr diff --git a/src/mapper/flop_counter.py b/src/mapper/flop_counter.py index 029f728..e590d94 100644 --- a/src/mapper/flop_counter.py +++ b/src/mapper/flop_counter.py @@ -7,9 +7,6 @@ class FlopCounter(CombineMapper): def combine(self, values): return sum(values) - def handle_unsupported_expression(self, expr, *args, **kwargs): - return 0 - def map_constant(self, expr): return 0 -- GitLab