From 48fb52c87df68628533cbdd02a69dbb7c0997872 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 5 Jun 2013 02:40:26 -0400
Subject: [PATCH] Tests passing with data tagging.

---
 MEMO                     |  2 +-
 doc/reference.rst        |  5 +++++
 loopy/__init__.py        | 33 +++++++++++++++++++++++++++++++++
 loopy/compiled.py        |  2 +-
 loopy/kernel/array.py    | 40 +++++++++++++++++++++++++---------------
 loopy/kernel/creation.py |  2 +-
 test/test_linalg.py      |  2 +-
 7 files changed, 67 insertions(+), 19 deletions(-)

diff --git a/MEMO b/MEMO
index 170019de5..8b9700092 100644
--- a/MEMO
+++ b/MEMO
@@ -53,7 +53,7 @@ To-do
 - rename IndexTag -> InameTag
 
 - Data implementation tags
-  TODO further:
+  - retag semantics once strides have been computed
   - turn base_indices into offset
   - vectorization
   - automatic copies
diff --git a/doc/reference.rst b/doc/reference.rst
index 2ada2947a..47805c2b1 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -201,8 +201,13 @@ Caching, Precomputation and Prefetching
 
     Uses :func:`extract_subst` and :func:`precompute`.
 
+Influencing data access
+^^^^^^^^^^^^^^^^^^^^^^^
+
 .. autofunction:: change_arg_to_image
 
+.. autofunction:: tag_data_axis
+
 Padding
 ^^^^^^^
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 60954f9cb..b4c7181d1 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -1035,7 +1035,40 @@ def change_arg_to_image(knl, name):
 # }}}
 
 
+# {{{ tag data axis
 
+def tag_data_axis(knl, ary_name, axis, tag):
+    """Return a copy of *knl* in which axis *axis* of the array *ary_name*
+    carries the data implementation tag *tag*.
+
+    :arg ary_name: name of a temporary variable or argument of *knl*;
+        temporary variables are looked up first.
+    :arg axis: index into the array's ``dim_tags``.
+    :arg tag: passed through :func:`loopy.kernel.array.parse_array_dim_tag`.
+    :raises NameError: if *knl* has no array named *ary_name*.
+    """
+    is_temporary = ary_name in knl.temporary_variables
+    if is_temporary:
+        ary = knl.temporary_variables[ary_name]
+    elif ary_name in knl.arg_dict:
+        ary = knl.arg_dict[ary_name]
+    else:
+        raise NameError("array '%s' was not found" % ary_name)
+
+    new_dim_tags = list(ary.dim_tags)
+    from loopy.kernel.array import parse_array_dim_tag
+    new_dim_tags[axis] = parse_array_dim_tag(tag)
+    ary = ary.copy(dim_tags=tuple(new_dim_tags))
+
+    if is_temporary:
+        new_tv = knl.temporary_variables.copy()
+        new_tv[ary_name] = ary
+        return knl.copy(temporary_variables=new_tv)
 
+    # Rebuild the argument list with the retagged array substituted in place.
+    return knl.copy(args=[
+        ary if arg.name == ary_name else arg for arg in knl.args])
+
+# }}}
 
 # vim: foldmethod=marker
diff --git a/loopy/compiled.py b/loopy/compiled.py
index 911919b3a..8312b99d2 100644
--- a/loopy/compiled.py
+++ b/loopy/compiled.py
@@ -500,7 +500,7 @@ def make_ref_args(kernel, cl_arg_info, queue, parameters, fill_value):
                 ref_args[arg.name] = ary
             else:
                 fill_rand(storage_array)
-                if isinstance(arg, ImageArg):
+                if arg.arg_class is ImageArg:
                     # must be contiguous
                     ref_args[arg.name] = cl.image_from_array(
                             queue.context, ary.get())
diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py
index 831996ba2..47707fb07 100644
--- a/loopy/kernel/array.py
+++ b/loopy/kernel/array.py
@@ -122,7 +122,7 @@ PADDED_STRIDE_TAG = re.compile(r"^([a-zA-Z]+)\(pad=(.*)\)$")
 TARGET_AXIS_RE = re.compile(r"->([0-9])$")
 
 
-def parse_array_dim_tag(tag):
+def parse_array_dim_tag(tag, default_target_axis=0):
     if isinstance(tag, ArrayDimImplementationTag):
         return tag
 
@@ -144,7 +144,7 @@ def parse_array_dim_tag(tag):
         target_axis = int(target_axis_match.group(1))
         tag = tag[:target_axis_match.start()]
     else:
-        target_axis = 0
+        target_axis = default_target_axis
 
     if tag in ["c", "C", "f", "F"]:
         return ComputedStrideArrayDimTag(tag, target_axis=target_axis)
@@ -162,16 +162,19 @@ def parse_array_dim_tag(tag):
         return ComputedStrideArrayDimTag(order, pad, target_axis=target_axis)
 
 
-def parse_array_dim_tags(dim_tags):
+def parse_array_dim_tags(dim_tags, use_increasing_target_axes=False):
     if isinstance(dim_tags, str):
         dim_tags = dim_tags.split(",")
 
-    def parse_dim_tag_if_necessary(dt):
-        if isinstance(dt, str):
-            dt = parse_array_dim_tag(dt)
-        return dt
+    default_target_axis = 0
 
-    return [parse_dim_tag_if_necessary(dt) for dt in dim_tags]
+    result = []
+    for dt in dim_tags:
+        result.append(parse_array_dim_tag(dt, default_target_axis))
+        if use_increasing_target_axes:
+            default_target_axis += 1
+
+    return result
 
 
 def convert_computed_to_fixed_dim_tags(name, num_user_axes, num_target_axes,
@@ -239,12 +242,13 @@ def convert_computed_to_fixed_dim_tags(name, num_user_axes, num_target_axes,
 
         if fixed_stride_dim_tags[target_axis]:
             for i in fixed_stride_dim_tags[target_axis]:
-                dt = dim_tags[i]
-                new_dim_tags[i] = dt
+                dim_tag = dim_tags[i]
+                new_dim_tags[i] = dim_tag
         else:
             for i in computed_stride_dim_tags[target_axis]:
-                dt = dim_tags[i]
-                new_dim_tags[i] = FixedStrideArrayDimTag(stride_so_far)
+                dim_tag = dim_tags[i]
+                new_dim_tags[i] = FixedStrideArrayDimTag(stride_so_far,
+                        target_axis=dim_tag.target_axis)
 
                 if shape is None:
                     # unable to normalize without known shape
@@ -252,10 +256,10 @@ def convert_computed_to_fixed_dim_tags(name, num_user_axes, num_target_axes,
 
                 stride_so_far *= shape[i]
 
-                if dt.pad_to is not None:
+                if dim_tag.pad_to is not None:
                     from pytools import div_ceil
                     stride_so_far = (
-                            div_ceil(stride_so_far, dt.pad_to)
+                            div_ceil(stride_so_far, dim_tag.pad_to)
                             * stride_so_far)
 
     # }}}
@@ -429,6 +433,11 @@ class ArrayBase(Record):
 
         # {{{ convert order to dim_tags
 
+        if order is None and self.max_target_axes > 1:
+            # FIXME: Hackety hack. ImageArgs need to generate dim_tags even
+            # if no order is specified. Plus they don't care that much.
+            order = "C"
+
         if dim_tags is None and num_user_axes is not None and order is not None:
             dim_tags = num_user_axes*[order]
             order = None
@@ -436,7 +445,8 @@ class ArrayBase(Record):
         # }}}
 
         if dim_tags is not None:
-            dim_tags = parse_array_dim_tags(dim_tags)
+            dim_tags = parse_array_dim_tags(dim_tags,
+                    use_increasing_target_axes=self.max_target_axes > 1)
 
             # {{{ find number of target axes
 
diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index a4c9c0521..002b27aa5 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -705,7 +705,7 @@ def dup_args_and_expand_defines_in_shapes(kernel, defines):
 
             new_arg = arg.copy(name=arg_name)
             if isinstance(arg, ArrayBase):
-                new_arg = arg.map_exprs(
+                new_arg = new_arg.map_exprs(
                         lambda expr: expand_defines_in_expr(expr, defines))
 
             processed_args.append(new_arg)
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 15073433d..9bf916221 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -475,7 +475,7 @@ def test_image_matrix_mul(ctx_factory):
 
     lp.auto_test_vs_ref(seq_knl, ctx, kernel_gen,
             op_count=[2*n**3/1e9], op_label=["GFlops"],
-            parameters={})
+            parameters={}, print_ref_code=True)
 
 
 def test_image_matrix_mul_ilp(ctx_factory):
-- 
GitLab