diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 165607a0527d36bf37565c13f3cc8bf9d249031c..124984ea3584a86732dc4067d7452210341a0e2a 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -24,12 +24,11 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - import numpy as np -from pymbolic.mapper import CSECachingMapperMixin, Collector +from pymbolic.mapper import CSECachingMapperMixin from loopy.tools import intern_frozenset_of_ids -from loopy.symbolic import IdentityMapper, WalkMapper +from loopy.symbolic import IdentityMapper, WalkMapper, CombineMapper from loopy.kernel.data import ( InstructionBase, MultiAssignmentBase, Assignment, @@ -43,6 +42,8 @@ from six.moves import range, zip, intern import re +from functools import reduce + import logging logger = logging.getLogger(__name__) @@ -1880,16 +1881,22 @@ class FunctionScoper(IdentityMapper): return IdentityMapper.map_call(self, expr) -class ScopedFunctionCollector(Collector): +class ScopedFunctionCollector(CombineMapper): """ This mapper would collect all the instances of :class:`ScopedFunction` occurring in the expression and written all of them as a :class:`set`. """ + def combine(self, values): + import operator + return reduce(operator.or_, values, frozenset()) def map_scoped_function(self, expr): - return set([expr.name]) + return frozenset([expr.name]) - def map_sub_array_ref(self, expr): - return set() + def map_constant(self, expr): + return frozenset() + + map_variable = map_constant + map_function_symbol = map_constant def scope_functions(kernel): diff --git a/loopy/preprocess.py b/loopy/preprocess.py index eedfca6f91ad890d1defb189778080279ebb6613..e7472ddd6d059ca19b6d48c1ad84a00a2976f376 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -2105,12 +2105,36 @@ def check_atomic_loads(kernel): # {{{ check for unscoped calls -class UnScopedCallCollector(Collector): +class UnScopedCallCollector(CombineMapper): + + def combine(self, values): + import operator + return reduce(operator.or_, values, frozenset()) + def map_call(self, expr): if not isinstance(expr.function, ScopedFunction): - return set([expr.function.name]) + return (frozenset([expr.function.name]) | + self.combine((self.rec(child) for child in expr.parameters))) + else: + return self.combine((self.rec(child) for child in expr.parameters)) + + def map_call_with_kwargs(self, expr): + if not isinstance(expr.function, ScopedFunction): + return (frozenset([expr.function.name]) | + self.combine((self.rec(child) for child in expr.parameters + + expr.kw_parameter.values()))) else: - return set() + return self.combine((self.rec(child) for child in + expr.parameters+expr.kw_parameters.values())) + + def map_scoped_function(self, expr): + return frozenset([expr.name]) + + def map_constant(self, expr): + return frozenset() + + map_variable = map_constant + map_function_symbol = map_constant def check_functions_are_scoped(kernel):