From a626687c655d697182349432b98fde82e87054fa Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 22 Mar 2018 17:07:33 -0500 Subject: [PATCH] Changed from collectors to combine mappers --- loopy/kernel/creation.py | 21 ++++++++++++++------- loopy/preprocess.py | 30 +++++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 165607a05..124984ea3 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -24,12 +24,11 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - import numpy as np -from pymbolic.mapper import CSECachingMapperMixin, Collector +from pymbolic.mapper import CSECachingMapperMixin from loopy.tools import intern_frozenset_of_ids -from loopy.symbolic import IdentityMapper, WalkMapper +from loopy.symbolic import IdentityMapper, WalkMapper, CombineMapper from loopy.kernel.data import ( InstructionBase, MultiAssignmentBase, Assignment, @@ -43,6 +42,8 @@ from six.moves import range, zip, intern import re +from functools import reduce + import logging logger = logging.getLogger(__name__) @@ -1880,16 +1881,22 @@ class FunctionScoper(IdentityMapper): return IdentityMapper.map_call(self, expr) -class ScopedFunctionCollector(Collector): +class ScopedFunctionCollector(CombineMapper): """ This mapper would collect all the instances of :class:`ScopedFunction` occurring in the expression and written all of them as a :class:`set`. """ + def combine(self, values): + import operator + return reduce(operator.or_, values, frozenset()) def map_scoped_function(self, expr): - return set([expr.name]) + return frozenset([expr.name]) - def map_sub_array_ref(self, expr): - return set() + def map_constant(self, expr): + return frozenset() + + map_variable = map_constant + map_function_symbol = map_constant def scope_functions(kernel): diff --git a/loopy/preprocess.py b/loopy/preprocess.py index eedfca6f9..e7472ddd6 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -2105,12 +2105,36 @@ def check_atomic_loads(kernel): # {{{ check for unscoped calls -class UnScopedCallCollector(Collector): +class UnScopedCallCollector(CombineMapper): + + def combine(self, values): + import operator + return reduce(operator.or_, values, frozenset()) + def map_call(self, expr): if not isinstance(expr.function, ScopedFunction): - return set([expr.function.name]) + return (frozenset([expr.function.name]) | + self.combine((self.rec(child) for child in expr.parameters))) + else: + return self.combine((self.rec(child) for child in expr.parameters)) + + def map_call_with_kwargs(self, expr): + if not isinstance(expr.function, ScopedFunction): + return (frozenset([expr.function.name]) | + self.combine((self.rec(child) for child in expr.parameters + + expr.kw_parameter.values()))) else: - return set() + return self.combine((self.rec(child) for child in + expr.parameters+expr.kw_parameters.values())) + + def map_scoped_function(self, expr): + return frozenset([expr.name]) + + def map_constant(self, expr): + return frozenset() + + map_variable = map_constant + map_function_symbol = map_constant def check_functions_are_scoped(kernel): -- GitLab