diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index c42db348234345a48efcb22b842fd114c5f65f14..f20bffcb368627d39664e2078c1502e4608a5724 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -1615,29 +1615,6 @@ def expand_defines_in_shapes(kernel, defines): # }}} -# {{{ guess argument shapes - -def guess_arg_shape_if_requested(kernel, default_order): - new_args = [] - - import loopy as lp - from loopy.kernel.array import ArrayBase - from loopy.kernel.tools import guess_var_shape - - for arg in kernel.args: - if isinstance(arg, ArrayBase) and arg.shape is lp.auto: - shape = guess_var_shape(kernel, arg.name) - - if arg.shape is lp.auto: - arg = arg.copy(shape=shape) - - new_args.append(arg) - - return kernel.copy(args=new_args) - -# }}} - - # {{{ apply default_order to args def apply_default_order_to_args(kernel, default_order): @@ -1907,6 +1884,9 @@ def make_kernel(domains, instructions, kernel_data=["..."], **kwargs): :arg fixed_parameters: A dictionary of *name*/*value* pairs, where *name* will be fixed to *value*. *name* may refer to :ref:`domain-parameters` or :ref:`arguments`. See also :func:`loopy.fix_parameters`. + :arg infer_shape: If *True*, array argument shapes will be inferred; if + *False*, shapes are left as `lp.auto` and can be inferred at the user's + request with `infer_arg_shapes`. :arg lang_version: The language version against which the kernel was written, a tuple. To ensure future compatibility, copy the current value of @@ -1954,6 +1934,7 @@ def make_kernel(domains, instructions, kernel_data=["..."], **kwargs): target = kwargs.pop("target", None) seq_dependencies = kwargs.pop("seq_dependencies", False) fixed_parameters = kwargs.pop("fixed_parameters", {}) + infer_shape = kwargs.pop("infer_shape", True) if defines: from warnings import warn @@ -2149,7 +2130,9 @@ def make_kernel(domains, instructions, kernel_data=["..."], **kwargs): knl = determine_shapes_of_temporaries(knl) knl = expand_defines_in_shapes(knl, defines) - knl = guess_arg_shape_if_requested(knl, default_order) + if infer_shape: + from loopy.transform.data import infer_arg_shapes + knl = infer_arg_shapes(knl) knl = apply_default_order_to_args(knl, default_order) knl = resolve_dependencies(knl) knl = apply_single_writer_depencency_heuristic(knl, warn_if_used=False) diff --git a/loopy/transform/data.py b/loopy/transform/data.py index 5b1ee6ccafa4c3b76609f197cc31691de562aaa5..22dc6247c6900682dedcb4324ea8d134eb3ab44c 100644 --- a/loopy/transform/data.py +++ b/loopy/transform/data.py @@ -758,4 +758,26 @@ def reduction_arg_to_subst_rule(knl, inames, insn_match=None, subst_rule_name=No # }}} +# {{{ infer argument shapes + +def infer_arg_shapes(kernel): + new_args = [] + + import loopy as lp + from loopy.kernel.array import ArrayBase + from loopy.kernel.tools import guess_var_shape + + for arg in kernel.args: + if isinstance(arg, ArrayBase) and arg.shape is lp.auto: + shape = guess_var_shape(kernel, arg.name) + + if arg.shape is lp.auto: + arg = arg.copy(shape=shape) + + new_args.append(arg) + + return kernel.copy(args=new_args) + +# }}} + # vim: foldmethod=marker