diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index f47144f94f37f39f41fc51b25c65b4e35b8883c1..190a80d3b345b2ec290e08e13b5dc1e257de5e92 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -1835,32 +1835,44 @@ def apply_single_writer_depencency_heuristic(kernel, warn_if_used=True): class FunctionScoper(IdentityMapper): + """ + Subclass of :class:`IdentityMapper` which converts functions known to + the kernel at to instances of :class:`ScopedFunction`. + + .. _example: + + If given an expression of the form `sin(x) + unknown_function(y) + + log(z)`, then the mapper would return `ScopedFunction('sin')(x) + + unknown_function(y) + ScopedFunction('log')(z)`. Since the + `unknown_function` is not known to the kernel it is not marked as a + `ScopedFunction`. + """ def __init__(self, function_ids): self.function_ids = function_ids def map_call(self, expr): + from loopy.symbolic import ScopedFunction if expr.function.name in self.function_ids: - # 1. need to change the function to ScopedFunction instead of Variable + # The function is one of the known function hence scoping it. from pymbolic.primitives import Call - from loopy.symbolic import ScopedFunction - return super(FunctionScoper, self).map_call( - Call(function=ScopedFunction(expr.function.name), - parameters=expr.parameters)) - - else: - return super(FunctionScoper, self).map_call(expr) + return Call( + ScopedFunction(expr.function.name), + tuple(self.rec(child) + for child in expr.parameters)) def map_call_with_kwargs(self, expr): if expr.function.name in self.function_ids: from pymbolic.primitives import CallWithKwargs from loopy.symbolic import ScopedFunction - return super(FunctionScoper, self).map_call_with_kwargs( - CallWithKwargs(function=ScopedFunction(expr.function.name), - parameters=expr.parameters, - kw_parameters=expr.kw_parameters)) - else: - return super(FunctionScoper, self).map_call_with_kwargs(expr) + return CallWithKwargs( + ScopedFunction(expr.function.name), + tuple(self.rec(child) + for child in expr.parameters), + dict( + (key, self.rec(val)) + for key, val in six.iteritems(expr.kw_parameters)) + ) class ScopedFunctionCollector(Collector): @@ -1868,6 +1880,8 @@ class ScopedFunctionCollector(Collector): def map_scoped_function(self, expr): return set([expr.name]) + map_sub_array_ref = Collector.map_constant + def scope_functions(kernel): func_ids = kernel.function_identifiers.copy() @@ -1887,7 +1901,7 @@ def scope_functions(kernel): elif isinstance(insn, _DataObliviousInstruction): new_insns.append(insn) else: - raise NotImplementedError("scope_function not implemented for %s" % + raise NotImplementedError("scope_functions not implemented for %s" % type(insn)) # Need to combine the scoped functions into a dict @@ -2235,8 +2249,6 @@ def make_kernel(domains, instructions, kernel_data=["..."], **kwargs): check_written_variable_names(knl) # Function Lookup - # TODO: here I add my function for function_lookup. Lol. realize the UN-inteded - # pun knl = scope_functions(knl) from loopy.preprocess import prepare_for_caching