From 165f8f9b9b8eba46469fa7c5e42543d0e3bfdfce Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 13 Jun 2013 23:04:39 -0400
Subject: [PATCH] Accept multiple (comma-separated) array names in
 tag_data_axes

---
 loopy/__init__.py | 54 +++++++++++++++++++++++++----------------------
 1 file changed, 29 insertions(+), 25 deletions(-)

diff --git a/loopy/__init__.py b/loopy/__init__.py
index 7851f8a32..a2d1f971d 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -1046,37 +1046,41 @@ def change_arg_to_image(knl, name):
 
 # {{{ tag data axes
 
-def tag_data_axes(knl, ary_name, dim_tags):
-    if ary_name in knl.temporary_variables:
-        ary = knl.temporary_variables[ary_name]
-    elif ary_name in knl.arg_dict:
-        ary = knl.arg_dict[ary_name]
-    else:
-        raise NameError("array '%s' was not found" % ary_name)
+def tag_data_axes(knl, ary_names, dim_tags):
+    for ary_name in ary_names.split(","):
+        ary_name = ary_name.strip()
+        if ary_name in knl.temporary_variables:
+            ary = knl.temporary_variables[ary_name]
+        elif ary_name in knl.arg_dict:
+            ary = knl.arg_dict[ary_name]
+        else:
+            raise NameError("array '%s' was not found" % ary_name)
 
-    from loopy.kernel.array import parse_array_dim_tags
-    new_dim_tags = parse_array_dim_tags(dim_tags,
-            use_increasing_target_axes=ary.max_target_axes > 1)
+        from loopy.kernel.array import parse_array_dim_tags
+        new_dim_tags = parse_array_dim_tags(dim_tags,
+                use_increasing_target_axes=ary.max_target_axes > 1)
 
-    ary = ary.copy(dim_tags=tuple(new_dim_tags))
+        ary = ary.copy(dim_tags=tuple(new_dim_tags))
 
-    if ary_name in knl.temporary_variables:
-        new_tv = knl.temporary_variables.copy()
-        new_tv[ary_name] = ary
-        return knl.copy(temporary_variables=new_tv)
+        if ary_name in knl.temporary_variables:
+            new_tv = knl.temporary_variables.copy()
+            new_tv[ary_name] = ary
+            knl = knl.copy(temporary_variables=new_tv)
 
-    elif ary_name in knl.arg_dict:
-        new_args = []
-        for arg in knl.args:
-            if arg.name == ary_name:
-                new_args.append(ary)
-            else:
-                new_args.append(arg)
+        elif ary_name in knl.arg_dict:
+            new_args = []
+            for arg in knl.args:
+                if arg.name == ary_name:
+                    new_args.append(ary)
+                else:
+                    new_args.append(arg)
 
-        return knl.copy(args=new_args)
+            knl = knl.copy(args=new_args)
 
-    else:
-        raise NameError("array '%s' was not found" % ary_name)
+        else:
+            raise NameError("array '%s' was not found" % ary_name)
+
+    return knl
 
 # }}}
 
-- 
GitLab