diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 883db10dc2a175a0702b66643b6e907a721ae936..3a2f888f8c69d6ae4afef51c253f4294fcd87f1e 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -1855,7 +1855,8 @@ class FunctionScoper(IdentityMapper): def map_call(self, expr): from loopy.symbolic import ScopedFunction - if expr.function.name in self.function_ids: + if not isinstance(expr.function, ScopedFunction) and ( + expr.function.name in self.function_ids): # The function is one of the known function hence scoping it. from pymbolic.primitives import Call @@ -1868,9 +1869,10 @@ class FunctionScoper(IdentityMapper): return IdentityMapper.map_call(self, expr) def map_call_with_kwargs(self, expr): - if expr.function.name in self.function_ids: + from loopy.symbolic import ScopedFunction + if not isinstance(expr.function, ScopedFunction) and ( + expr.function.name in self.function_ids): from pymbolic.primitives import CallWithKwargs - from loopy.symbolic import ScopedFunction return CallWithKwargs( ScopedFunction(expr.function.name), tuple(self.rec(child) @@ -1887,6 +1889,10 @@ class FunctionScoper(IdentityMapper): from pymbolic.primitives import Variable from loopy.symbolic import ScopedFunction + if isinstance(expr.function, ScopedFunction): + # we have already scoped this function. + return IdentityMapper.map_reduction(self, expr) + mapped_inames = [self.rec(Variable(iname)) for iname in expr.inames] new_inames = [] @@ -1915,13 +1921,20 @@ class ScopedFunctionCollector(CombineMapper): """ This mapper would collect all the instances of :class:`ScopedFunction` occurring in the expression and written all of them as a :class:`set`. """ + def __init__(self, already_scoped_functions={}): + self.already_scoped_functions = already_scoped_functions + def combine(self, values): import operator return reduce(operator.or_, values, frozenset()) def map_scoped_function(self, expr): from loopy.kernel.function_interface import CallableOnScalar - return frozenset([(expr.name, CallableOnScalar(expr.name))]) + if expr.name in self.already_scoped_functions: + # functions is already scoped + return frozenset() + else: + return frozenset([(expr.name, CallableOnScalar(expr.name))]) def map_reduction(self, expr): from loopy.kernel.function_interface import (CallableOnScalar, @@ -1931,6 +1944,10 @@ class ScopedFunctionCollector(CombineMapper): # Refer to `map_reduction` subroutine of `FunctionScoper`. assert expr.function.name[-7:] == "_reduce" + if expr.function.name in self.already_scoped_functions: + # the function is already scoped + return self.rec(expr.expr) + callable_reduction = CallableReduction(expr.function.name[:-7]) # sanity checks @@ -1962,7 +1979,6 @@ class ScopedFunctionCollector(CombineMapper): hidden_function = callable_reduction.operation.hidden_function() if hidden_function is not None: - return ( frozenset([(expr.function.name, callable_reduction), (hidden_function, CallableOnScalar(hidden_function))]) | @@ -1986,15 +2002,17 @@ def scope_functions(kernel): from loopy.kernel.instruction import CInstruction, _DataObliviousInstruction function_scoper = FunctionScoper(func_ids) - scoped_function_collector = ScopedFunctionCollector() - scoped_functions = set() + scoped_function_collector = ScopedFunctionCollector( + kernel.scoped_functions) + new_scoped_functions = set() new_insns = [] for insn in kernel.instructions: if isinstance(insn, (MultiAssignmentBase, CInstruction)): new_insn = insn.copy(expression=function_scoper(insn.expression)) - scoped_functions.update(scoped_function_collector(new_insn.expression)) + new_scoped_functions.update(scoped_function_collector( + new_insn.expression)) new_insns.append(new_insn) elif isinstance(insn, _DataObliviousInstruction): new_insns.append(insn) @@ -2002,19 +2020,21 @@ def scope_functions(kernel): raise NotImplementedError("scope_functions not implemented for %s" % type(insn)) - scoped_substitutions = {} + substitutions_with_scoped_expr = {} for name, rule in kernel.substitutions.items(): scoped_rule = rule.copy( expression=function_scoper(rule.expression)) - scoped_substitutions[name] = scoped_rule - scoped_functions.update(scoped_function_collector(scoped_rule.expression)) + substitutions_with_scoped_expr[name] = scoped_rule + new_scoped_functions.update(scoped_function_collector( + scoped_rule.expression)) # Need to combine the scoped functions into a dict - scoped_function_dict = dict(scoped_functions) + updated_scoped_functions = kernel.scoped_functions.copy() + updated_scoped_functions.update(dict(new_scoped_functions)) return kernel.copy(instructions=new_insns, - scoped_functions=scoped_function_dict, - substitutions=scoped_substitutions) + scoped_functions=updated_scoped_functions, + substitutions=substitutions_with_scoped_expr) # }}}