From 165f8f9b9b8eba46469fa7c5e42543d0e3bfdfce Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 13 Jun 2013 23:04:39 -0400 Subject: [PATCH] Accept multiple (comma-separated) array names in tag_data_axes --- loopy/__init__.py | 54 +++++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/loopy/__init__.py b/loopy/__init__.py index 7851f8a32..a2d1f971d 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -1046,37 +1046,41 @@ def change_arg_to_image(knl, name): # {{{ tag data axes -def tag_data_axes(knl, ary_name, dim_tags): - if ary_name in knl.temporary_variables: - ary = knl.temporary_variables[ary_name] - elif ary_name in knl.arg_dict: - ary = knl.arg_dict[ary_name] - else: - raise NameError("array '%s' was not found" % ary_name) +def tag_data_axes(knl, ary_names, dim_tags): + for ary_name in ary_names.split(","): + ary_name = ary_name.strip() + if ary_name in knl.temporary_variables: + ary = knl.temporary_variables[ary_name] + elif ary_name in knl.arg_dict: + ary = knl.arg_dict[ary_name] + else: + raise NameError("array '%s' was not found" % ary_name) - from loopy.kernel.array import parse_array_dim_tags - new_dim_tags = parse_array_dim_tags(dim_tags, - use_increasing_target_axes=ary.max_target_axes > 1) + from loopy.kernel.array import parse_array_dim_tags + new_dim_tags = parse_array_dim_tags(dim_tags, + use_increasing_target_axes=ary.max_target_axes > 1) - ary = ary.copy(dim_tags=tuple(new_dim_tags)) + ary = ary.copy(dim_tags=tuple(new_dim_tags)) - if ary_name in knl.temporary_variables: - new_tv = knl.temporary_variables.copy() - new_tv[ary_name] = ary - return knl.copy(temporary_variables=new_tv) + if ary_name in knl.temporary_variables: + new_tv = knl.temporary_variables.copy() + new_tv[ary_name] = ary + knl = knl.copy(temporary_variables=new_tv) - elif ary_name in knl.arg_dict: - new_args = [] - for arg in knl.args: - if arg.name == ary_name: - new_args.append(ary) - else: - new_args.append(arg) + elif ary_name in knl.arg_dict: + new_args = [] + for arg in knl.args: + if arg.name == ary_name: + new_args.append(ary) + else: + new_args.append(arg) - return knl.copy(args=new_args) + knl = knl.copy(args=new_args) - else: - raise NameError("array '%s' was not found" % ary_name) + else: + raise NameError("array '%s' was not found" % ary_name) + + return knl # }}} -- GitLab