From 22ac610b65fc9b60309ac87eab5eaf6b64f64bbe Mon Sep 17 00:00:00 2001 From: zachjweiner Date: Thu, 13 Dec 2018 17:13:02 -0600 Subject: [PATCH] replaces guess_arg_shape_if_requested with infer_arg_shapes and adds infer_shape option to make_kernel --- loopy/kernel/creation.py | 31 +++++++------------------------ loopy/transform/data.py | 22 ++++++++++++++++++++++ 2 files changed, 29 insertions(+), 24 deletions(-) diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index c42db3482..f20bffcb3 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -1615,29 +1615,6 @@ def expand_defines_in_shapes(kernel, defines): # }}} -# {{{ guess argument shapes - -def guess_arg_shape_if_requested(kernel, default_order): - new_args = [] - - import loopy as lp - from loopy.kernel.array import ArrayBase - from loopy.kernel.tools import guess_var_shape - - for arg in kernel.args: - if isinstance(arg, ArrayBase) and arg.shape is lp.auto: - shape = guess_var_shape(kernel, arg.name) - - if arg.shape is lp.auto: - arg = arg.copy(shape=shape) - - new_args.append(arg) - - return kernel.copy(args=new_args) - -# }}} - - # {{{ apply default_order to args def apply_default_order_to_args(kernel, default_order): @@ -1907,6 +1884,9 @@ def make_kernel(domains, instructions, kernel_data=["..."], **kwargs): :arg fixed_parameters: A dictionary of *name*/*value* pairs, where *name* will be fixed to *value*. *name* may refer to :ref:`domain-parameters` or :ref:`arguments`. See also :func:`loopy.fix_parameters`. + :arg infer_shape: If *True*, array argument shapes will be inferred; if + *False*, shapes are left as `lp.auto` and can be inferred at the user's + request with `infer_arg_shapes`. :arg lang_version: The language version against which the kernel was written, a tuple. To ensure future compatibility, copy the current value of @@ -1954,6 +1934,7 @@ def make_kernel(domains, instructions, kernel_data=["..."], **kwargs): target = kwargs.pop("target", None) seq_dependencies = kwargs.pop("seq_dependencies", False) fixed_parameters = kwargs.pop("fixed_parameters", {}) + infer_shape = kwargs.pop("infer_shape", True) if defines: from warnings import warn @@ -2149,7 +2130,9 @@ def make_kernel(domains, instructions, kernel_data=["..."], **kwargs): knl = determine_shapes_of_temporaries(knl) knl = expand_defines_in_shapes(knl, defines) - knl = guess_arg_shape_if_requested(knl, default_order) + if infer_shape: + from loopy.transform.data import infer_arg_shapes + knl = infer_arg_shapes(knl) knl = apply_default_order_to_args(knl, default_order) knl = resolve_dependencies(knl) knl = apply_single_writer_depencency_heuristic(knl, warn_if_used=False) diff --git a/loopy/transform/data.py b/loopy/transform/data.py index 5b1ee6cca..22dc6247c 100644 --- a/loopy/transform/data.py +++ b/loopy/transform/data.py @@ -758,4 +758,26 @@ def reduction_arg_to_subst_rule(knl, inames, insn_match=None, subst_rule_name=No # }}} +# {{{ infer argument shapes + +def infer_arg_shapes(kernel): + new_args = [] + + import loopy as lp + from loopy.kernel.array import ArrayBase + from loopy.kernel.tools import guess_var_shape + + for arg in kernel.args: + if isinstance(arg, ArrayBase) and arg.shape is lp.auto: + shape = guess_var_shape(kernel, arg.name) + + if arg.shape is lp.auto: + arg = arg.copy(shape=shape) + + new_args.append(arg) + + return kernel.copy(args=new_args) + +# }}} + # vim: foldmethod=marker -- GitLab