diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index e4caa7148e7dd5fd352db6e101e45225892da29d..0bf9bb548f7c5f6cf688eea6f1aa4d2bdaba1139 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -337,6 +337,25 @@ def parse_domains(ctx, domains, defines): # }}} +# {{{ duplicate arguments with commas in name + +def duplicate_args_with_commas(kernel_args): + processed_args = [] + for arg in kernel_args: + if arg is Ellipsis or isinstance(arg, str): + processed_args.append(arg) + else: + for arg_name in arg.name.split(","): + if not arg_name.strip(): + continue + + processed_args.append(arg.copy(name=arg_name)) + + return processed_args + +# }}} + + # {{{ guess kernel args (if requested) class IndexRankFinder(WalkMapper): @@ -703,24 +722,19 @@ def check_for_reduction_inames_duplication_requests(kernel): # }}} -# {{{ duplicate arguments and expand defines in shapes +# {{{ expand defines in shapes -def dup_args_and_expand_defines_in_shapes(kernel, defines): +def expand_defines_in_shapes(kernel, defines): from loopy.kernel.array import ArrayBase from loopy.kernel.creation import expand_defines_in_expr processed_args = [] for arg in kernel.args: - for arg_name in arg.name.split(","): - if not arg_name.strip(): - continue - - new_arg = arg.copy(name=arg_name) - if isinstance(arg, ArrayBase): - new_arg = new_arg.map_exprs( - lambda expr: expand_defines_in_expr(expr, defines)) + if isinstance(arg, ArrayBase): + arg = arg.map_exprs( + lambda expr: expand_defines_in_expr(expr, defines)) - processed_args.append(new_arg) + processed_args.append(arg) return kernel.copy(args=processed_args) @@ -893,7 +907,8 @@ def make_kernel(device, domains, instructions, kernel_args=["..."], **kwargs): domains = parse_domains(isl_context, domains, defines) kernel_args = guess_kernel_args_if_requested(domains, instructions, - kwargs.get("temporary_variables", {}), substitutions, kernel_args, + kwargs.get("temporary_variables", {}), substitutions, + duplicate_args_with_commas(kernel_args), default_offset) from loopy.kernel import LoopKernel @@ -905,7 +920,7 @@ def make_kernel(device, domains, instructions, kernel_args=["..."], **kwargs): knl = tag_reduction_inames_as_sequential(knl) knl = create_temporaries(knl) knl = expand_cses(knl) - knl = dup_args_and_expand_defines_in_shapes(knl, defines) + knl = expand_defines_in_shapes(knl, defines) knl = guess_arg_shape_if_requested(knl, default_order) knl = apply_default_order_to_args(knl, default_order)