diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 5bfa929b1ee6f1cf4c418259c735521be8da44f9..a009441d86b454c33797ff20290d875350b3ebd1 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -217,20 +217,6 @@ class StringifyMapper(StringifyMapperBase): enclosing_prec, PREC_CALL) -class DependencyMapper(DependencyMapperBase): - def map_reduction(self, expr): - return (self.rec(expr.expr) - - set(Variable(iname) for iname in expr.inames)) - - def map_tagged_variable(self, expr): - return set([expr]) - - def map_loopy_function_identifier(self, expr): - return set() - - map_linear_subscript = DependencyMapperBase.map_subscript - - class UnidirectionalUnifier(UnidirectionalUnifierBase): def map_reduction(self, expr, other, unis): if not isinstance(other, type(expr)): @@ -257,9 +243,37 @@ class UnidirectionalUnifier(UnidirectionalUnifierBase): from pymbolic.mapper.unifier import unify_many return unify_many(urecs, new_uni_record) + +class DependencyMapper(DependencyMapperBase): + def map_call(self, expr, *args): + # Loopy does not have first-class functions. Do not descend + # into 'function' attribute of Call. + return self.combine( + self.rec(child, *args) for child in expr.parameters) + + def map_reduction(self, expr): + return (self.rec(expr.expr) + - set(Variable(iname) for iname in expr.inames)) + + def map_tagged_variable(self, expr): + return set([expr]) + + def map_loopy_function_identifier(self, expr): + return set() + + map_linear_subscript = DependencyMapperBase.map_subscript + # }}} +@memoize +def get_dependencies(expr): + from loopy.symbolic import DependencyMapper + dep_mapper = DependencyMapper(composite_leaves=False) + + return frozenset(dep.name for dep in dep_mapper(expr)) + + # {{{ identity mapper that expands subst rules on the fly def parse_tagged_name(expr): @@ -915,14 +929,6 @@ class PrimeAdder(IdentityMapper): # }}} -@memoize -def get_dependencies(expr): - from loopy.symbolic import DependencyMapper - dep_mapper = DependencyMapper(composite_leaves=False) - - return frozenset(dep.name for dep in dep_mapper(expr)) - - # {{{ get access range def get_access_range(domain, subscript):