diff --git a/loopy/__init__.py b/loopy/__init__.py index d621f0591a9c39328ef58682b6b168b030807669..ba013365ce932819eeb83101c73e567e5ee62d20 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -123,7 +123,7 @@ from loopy.transform.pack_and_unpack_args import pack_and_unpack_args_for_call from loopy.type_inference import infer_unknown_types from loopy.preprocess import (preprocess_kernel, realize_reduction, - preprocess_program) + preprocess_program, infer_arg_descr) from loopy.schedule import ( generate_loop_schedules, get_one_scheduled_kernel, get_one_linearized_kernel) from loopy.statistics import (ToCountMap, ToCountPolynomialMap, CountGranularity, @@ -258,6 +258,8 @@ __all__ = [ "infer_unknown_types", "preprocess_kernel", "realize_reduction", "preprocess_program", + "infer_arg_descr", + "generate_loop_schedules", "get_one_scheduled_kernel", "get_one_linearized_kernel", "GeneratedProgram", "CodeGenerationResult", diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index 784e8412a682f018a5e5bfb7672d75afa0e1eaf6..7d3df545d49531a5c26a561d6052e08adb510643 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -627,6 +627,7 @@ def generate_code_v2(program): :param program: An instance of :class:`loopy.Program`. """ + from loopy.kernel import LoopKernel from loopy.program import make_program from loopy.codegen.result import CodeGenerationResult @@ -634,9 +635,10 @@ def generate_code_v2(program): # {{{ cache retrieval from loopy import CACHING_ENABLED + from loopy.preprocess import prepare_for_caching if CACHING_ENABLED: - input_program = program + input_program = prepare_for_caching(program) try: result = code_gen_cache[input_program] logger.debug(f"Program with entrypoints {program.entrypoints}:" diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index ece606ddceaae5f09727ec46d5691a6a5decb2e4..d176488b632cd650799c0a9eb50dcbd2661588cd 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -396,7 +396,7 @@ class ArrayArg(ArrayBase, KernelArgument): if "address_space" not in kwargs: raise TypeError("'address_space' must be specified") - is_output_only = kwargs.pop("is_output_only", False) + is_output_only = kwargs.pop("is_output_only", None) if is_output_only is not None: warn("'is_output_only' is deprecated. Use 'is_output', 'is_input'" " instead.", DeprecationWarning, stacklevel=2) diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 4acadcfe085ddd1b70ea7dbd503ab5ba8c8b8cb9..7f5979a04f8f29184447ecf7cfa433b793d86749 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -32,13 +32,9 @@ import islpy as isl from islpy import dim_type from loopy.diagnostic import LoopyError, warn_with_kernel from pytools import memoize_on_first_arg, natsorted -from loopy.symbolic import CombineMapper from loopy.kernel import LoopKernel from loopy.program import Program, iterate_over_kernels_if_given_program from loopy.kernel.function_interface import CallableKernel -from loopy.kernel.instruction import (MultiAssignmentBase, - _DataObliviousInstruction) -from functools import reduce import logging logger = logging.getLogger(__name__) @@ -1982,61 +1978,4 @@ def infer_args_are_input_output(kernel): # }}} - -# {{{ identify_root_kernel - -class CallCollector(CombineMapper): - def combine(self, values): - import operator - return reduce(operator.or_, values, frozenset()) - - def map_call(self, expr): - from pymbolic.primitives import CallWithKwargs - return self.rec(CallWithKwargs( - function=expr.function, parameters=expr.parameters, - kw_parameters={})) - - def map_call_with_kwargs(self, expr): - return (frozenset([expr.function.name]) | - self.combine(self.rec(child) for child in expr.parameters - + tuple(expr.kw_parameters.values()))) - - def map_constant(self, expr): - return frozenset() - - map_variable = map_constant - map_function_symbol = map_constant - map_tagged_variable = map_constant - map_type_cast = map_constant - - -def identify_root_kernel(kernels): - assert isinstance(kernels, list) - assert all(isinstance(knl, LoopKernel) for knl in kernels) - call_collector = CallCollector() - - def _calls_in_a_kernel(knl): - calls = set() - for insn in knl.instructions: - if isinstance(insn, MultiAssignmentBase): - calls = calls | call_collector(insn.expression) - elif isinstance(insn, _DataObliviousInstruction): - pass - else: - raise NotImplementedError() - - return calls - - all_calls = frozenset().union(*[_calls_in_a_kernel(knl) for knl in - kernels]) - - kernel_names = frozenset([knl.name for knl in kernels]) - - assert len(kernel_names - all_calls) == 1 - - root_knl_name, = (kernel_names - all_calls) - return root_knl_name - -# }}} - # vim: foldmethod=marker diff --git a/loopy/preprocess.py b/loopy/preprocess.py index dd14b0eb4c4f81088e2bcf69ffa100f8f12eab02..e377adc286f50ac77c461c52d5803bc4921e6919 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -2175,7 +2175,7 @@ class ArgDescrInferenceMapper(RuleAwareIdentityMapper): def traverse_to_infer_arg_descr(kernel, callables_table): """ Returns a copy of *kernel* with the argument shapes and strides matching for - scoped functions in the *kernel*. Refer + resolved functions in the *kernel*. Refer :meth:`loopy.kernel.function_interface.InKernelCallable.with_descrs`. .. note:: @@ -2202,22 +2202,20 @@ def infer_arg_descr(program): :attr:`loopy.InKernelCallable.arg_id_to_descr` inferred for all the callables. """ - - from loopy.program import make_clbl_inf_ctx + from loopy.program import make_clbl_inf_ctx, resolve_callables from loopy.kernel.array import ArrayBase from loopy.kernel.function_interface import (ArrayArgDescriptor, ValueArgDescriptor) from loopy import auto, ValueArg + program = resolve_callables(program) + clbl_inf_ctx = make_clbl_inf_ctx(program.callables_table, program.entrypoints) renamed_entrypoints = set() for e in program.entrypoints: - # FIXME: Need to add docs which say that we need not add the current - # callable to the clbl_inf_ctx while writing the "with_types" - # This is treacherous, we should use traverse... instead. def _tuple_if_int(s): if isinstance(s, int): return s,