diff --git a/src/__init__.py b/src/__init__.py index a7c69c04c36e665b9034b3df13d97af37483e8ac..de5ad50958a7447a9e4f8376b7c210cce4e8a2f2 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -23,8 +23,6 @@ parse = pymbolic.parser.parse evaluate = pymbolic.mapper.evaluator.evaluate evaluate_kw = pymbolic.mapper.evaluator.evaluate_kw compile = pymbolic.compiler.compile -is_constant = pymbolic.mapper.dependency.is_constant -get_dependencies = pymbolic.mapper.dependency.get_dependencies substitute = pymbolic.mapper.substitutor.substitute differentiate = pymbolic.mapper.differentiator.differentiate expand = pymbolic.mapper.expander.expand diff --git a/src/compiler.py b/src/compiler.py index 90974f7c85b4bfbaf6ea276f4e7c9ddc8e9d2062..db5fec38884504455b5bd7731c325f265fc2bd83 100644 --- a/src/compiler.py +++ b/src/compiler.py @@ -60,8 +60,9 @@ class CompiledExpression: def __compile__(self): ctx = self.context() - used_variables = pymbolic.get_dependencies(self._Expression, - composite_leaves=False) + from pymbolic.mapper.dependency import DependencyMapper + used_variables = DependencyMapper( + composite_leaves=False)(self._Expression) used_variables -= set(self._Variables) used_variables -= set(pymbolic.var(key) for key in ctx.keys()) used_variables = list(used_variables) diff --git a/src/mapper/collector.py b/src/mapper/collector.py index bc1bc5344adb16082441b92e417fa32e8c6457d1..2dfc8d554846db139523b8b4e7798e47945899eb 100644 --- a/src/mapper/collector.py +++ b/src/mapper/collector.py @@ -78,11 +78,3 @@ class TermCollector(IdentityMapper): result = pymbolic.flattened_sum(coeff*rep2term(termrep) for termrep, coeff in term2coeff.iteritems()) return result - - def handle_unsupported_expression(self, expr): - from pymbolic.primitives import AlgebraicLeaf - if isinstance(expr, AlgebraicLeaf): - return expr - else: - IdentityMapper.handle_unsupported_expression(self, expr) - diff --git a/src/mapper/constant_folder.py b/src/mapper/constant_folder.py index 5da8218bbb1b1a54f6a47b24f881c0e8f13c4549..c9fa19834922f7f5d25df271ccca9277cf574d3d 100644 --- a/src/mapper/constant_folder.py +++ b/src/mapper/constant_folder.py @@ -4,8 +4,12 @@ from pymbolic.mapper import IdentityMapper, NonrecursiveIdentityMapper class ConstantFoldingMapperBase(object): + def is_constant(self, expr): + from pymbolic.mapper.dependency import DependencyMapper + return not bool(DependencyMapper()(expr)) + def fold(self, expr, klass, op, constructor): - from pymbolic import is_constant, evaluate + from pymbolic import evaluate constants = [] nonconstants = [] @@ -16,7 +20,7 @@ class ConstantFoldingMapperBase(object): if isinstance(child, klass): queue = list(child.children) + queue else: - if is_constant(child): + if self.is_constant(child): constants.append(evaluate(child)) else: nonconstants.append(child) @@ -34,9 +38,6 @@ class ConstantFoldingMapperBase(object): return self.fold(expr, Sum, operator.add, Sum) - def handle_unsupported_expression(self, expr): - return expr - class CommutativeConstantFoldingMapperBase(ConstantFoldingMapperBase): diff --git a/src/mapper/dependency.py b/src/mapper/dependency.py index aea10c4f3ac0463baa1c46f26f96298727feaf04..396be4f38c3e2fa84837a54f611bcf6e476647eb 100644 --- a/src/mapper/dependency.py +++ b/src/mapper/dependency.py @@ -30,13 +30,6 @@ class DependencyMapper(CombineMapper): import operator return reduce(operator.or_, values, set()) - def handle_unsupported_expression(self, expr): - from pymbolic.primitives import AlgebraicLeaf - if isinstance(expr, AlgebraicLeaf): - return set([expr]) - else: - CombineMapper.handle_unsupported_expression(self, expr) - def map_constant(self, expr): return set() @@ -69,20 +62,3 @@ class DependencyMapper(CombineMapper): return set([expr]) else: return CombineMapper.map_common_subexpression(self, expr) - - - - - -def get_dependencies(expr, **kwargs): - return DependencyMapper(**kwargs)(expr) - - - - -def is_constant(expr, with_respect_to=None, **kwargs): - if not with_respect_to: - return not bool(get_dependencies(expr, **kwargs)) - else: - return not (set(with_respect_to) & get_dependencies(expr, **kwargs)) - diff --git a/src/mapper/evaluator.py b/src/mapper/evaluator.py index 98799c952c925e7b8eae0336135b39b50252bf7f..82900b829020cb23a385a2b512f71fde514a1a05 100644 --- a/src/mapper/evaluator.py +++ b/src/mapper/evaluator.py @@ -72,7 +72,7 @@ class EvaluationMapper(RecursiveMapper): result[i] = self.rec(expr[i]) return result - def map_common_subexpression(self, expr, out=None): + def map_common_subexpression(self, expr): try: return self.common_subexp_cache[expr.child] except KeyError: @@ -83,12 +83,6 @@ class EvaluationMapper(RecursiveMapper): class FloatEvaluationMapper(EvaluationMapper): - def handle_unsupported_expression(self, expr): - try: - return float(expr) - except: - raise TypeError, "cannot convert %s to float" % type(expr) - def map_constant(self, expr): return float(expr) diff --git a/src/mapper/expander.py b/src/mapper/expander.py index 0822075333fa286a1c9792f5e29fc5acfa688306..6978bc76b20765c63447bbcd2f8a9eef539fe2d9 100644 --- a/src/mapper/expander.py +++ b/src/mapper/expander.py @@ -66,8 +66,6 @@ class ExpandMapper(IdentityMapper): else: return IdentitityMapper.map_power(expr) - def handle_unsupported_expression(self, expr): - return expr diff --git a/src/mapper/flattener.py b/src/mapper/flattener.py index ee05303f2ad00476be9c95e4c689795565e69621..102ca90faebcabcf3b4c0ec172acf9e3b1055cfc 100644 --- a/src/mapper/flattener.py +++ b/src/mapper/flattener.py @@ -12,8 +12,6 @@ class FlattenMapper(IdentityMapper): from pymbolic.primitives import flattened_product return flattened_product(self.rec(ch) for ch in expr.children) - def handle_unsupported_expression(self, expr): - return expr diff --git a/src/mapper/flop_counter.py b/src/mapper/flop_counter.py index 029f7282f576870763dcc40f250361cb54ba73f8..e590d94f1acadacbc132b73bd46e39bc3651786c 100644 --- a/src/mapper/flop_counter.py +++ b/src/mapper/flop_counter.py @@ -7,9 +7,6 @@ class FlopCounter(CombineMapper): def combine(self, values): return sum(values) - def handle_unsupported_expression(self, expr, *args, **kwargs): - return 0 - def map_constant(self, expr): return 0