diff --git a/loopy/__init__.py b/loopy/__init__.py index 7851f8a32926f2bd377da28016a3c520f2a4caf4..a2d1f971d8ea71b11dbdf2e849843a1e5d381b65 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -1046,37 +1046,41 @@ def change_arg_to_image(knl, name): # {{{ tag data axes -def tag_data_axes(knl, ary_name, dim_tags): - if ary_name in knl.temporary_variables: - ary = knl.temporary_variables[ary_name] - elif ary_name in knl.arg_dict: - ary = knl.arg_dict[ary_name] - else: - raise NameError("array '%s' was not found" % ary_name) +def tag_data_axes(knl, ary_names, dim_tags): + for ary_name in ary_names.split(","): + ary_name = ary_name.strip() + if ary_name in knl.temporary_variables: + ary = knl.temporary_variables[ary_name] + elif ary_name in knl.arg_dict: + ary = knl.arg_dict[ary_name] + else: + raise NameError("array '%s' was not found" % ary_name) - from loopy.kernel.array import parse_array_dim_tags - new_dim_tags = parse_array_dim_tags(dim_tags, - use_increasing_target_axes=ary.max_target_axes > 1) + from loopy.kernel.array import parse_array_dim_tags + new_dim_tags = parse_array_dim_tags(dim_tags, + use_increasing_target_axes=ary.max_target_axes > 1) - ary = ary.copy(dim_tags=tuple(new_dim_tags)) + ary = ary.copy(dim_tags=tuple(new_dim_tags)) - if ary_name in knl.temporary_variables: - new_tv = knl.temporary_variables.copy() - new_tv[ary_name] = ary - return knl.copy(temporary_variables=new_tv) + if ary_name in knl.temporary_variables: + new_tv = knl.temporary_variables.copy() + new_tv[ary_name] = ary + knl = knl.copy(temporary_variables=new_tv) - elif ary_name in knl.arg_dict: - new_args = [] - for arg in knl.args: - if arg.name == ary_name: - new_args.append(ary) - else: - new_args.append(arg) + elif ary_name in knl.arg_dict: + new_args = [] + for arg in knl.args: + if arg.name == ary_name: + new_args.append(ary) + else: + new_args.append(arg) - return knl.copy(args=new_args) + knl = knl.copy(args=new_args) - else: - raise NameError("array '%s' was not found" % ary_name) + else: + raise NameError("array '%s' was not found" % ary_name) + + return knl # }}}