From 37ea03c8fea80342397a2ac41162f4fd4fd87bc6 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 12 Jun 2014 10:23:13 +0100
Subject: [PATCH] Implement make_copy_kernel

---
 doc/reference.rst  |  2 ++
 loopy/__init__.py  | 49 ++++++++++++++++++++++++++++++++++++++++++++++
 test/test_loopy.py | 30 +++++++++++++++++++++++-----
 3 files changed, 76 insertions(+), 5 deletions(-)

diff --git a/doc/reference.rst b/doc/reference.rst
index 4892b1ad7..3ad6099fe 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -320,6 +320,8 @@ function, which is responsible for creating kernels:
 
 .. autofunction:: make_kernel
 
+.. autofunction:: make_copy_kernel
+
 Transforming Kernels
 --------------------
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 5d9ad93e7..3d6f71bec 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -1381,4 +1381,53 @@ class CacheMode(object):
 # }}}
 
 
+# {{{ data layout change
+
+def make_copy_kernel(new_dim_tags, old_dim_tags=None):
+    """Return a :class:`LoopKernel` that copies the array "input" to the
+    array "output", changing the data layout from the one specified by
+    *old_dim_tags* to the one specified by *new_dim_tags*.
+
+    *old_dim_tags* may be a string (which is parsed like *new_dim_tags*)
+    or *None*, which selects an all-C layout of the same rank as
+    *new_dim_tags*. Any other value is used as given.
+
+    :raises ValueError: if the two layouts do not have the same rank.
+    """
+
+    from loopy.kernel.array import (parse_array_dim_tags,
+            SeparateArrayArrayDimTag, VectorArrayDimTag)
+    new_dim_tags = parse_array_dim_tags(new_dim_tags)
+
+    rank = len(new_dim_tags)
+    if old_dim_tags is None:
+        old_dim_tags = parse_array_dim_tags(",".join(rank * ["c"]))
+    elif isinstance(old_dim_tags, str):
+        old_dim_tags = parse_array_dim_tags(old_dim_tags)
+
+    if len(old_dim_tags) != rank:
+        raise ValueError("old_dim_tags has rank %d, but new_dim_tags "
+                "has rank %d" % (len(old_dim_tags), rank))
+
+    # One iname i<k> and one size parameter n<k> per array axis.
+    indices = ["i%d" % i for i in range(rank)]
+    commad_indices = ", ".join(indices)
+    bounds = " and ".join("0<=i%d<n%d" % (i, i) for i in range(rank))
+
+    result = make_kernel("{[%s]: %s}" % (commad_indices, bounds),
+            "output[%s] = input[%s]" % (commad_indices, commad_indices))
+    result = tag_data_axes(result, "input", old_dim_tags)
+    result = tag_data_axes(result, "output", new_dim_tags)
+
+    # Unroll loops over axes that either layout tags with *unrolled_tags*.
+    unrolled_tags = (SeparateArrayArrayDimTag, VectorArrayDimTag)
+    for iname, tags in zip(indices, zip(new_dim_tags, old_dim_tags)):
+        if any(isinstance(tag, unrolled_tags) for tag in tags):
+            result = tag_inames(result, {iname: "unr"})
+
+    return result
+
+# }}}
+
+
 # vim: foldmethod=marker
diff --git a/test/test_loopy.py b/test/test_loopy.py
index d17ec1c0c..f8cbbc2ae 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -1689,9 +1689,6 @@ def test_multiple_writes_to_local_temporary(ctx_factory):
 
 
 def test_fd_demo(ctx_factory):
-    ctx = ctx_factory()
-    queue = cl.CommandQueue(ctx)
-
     knl = lp.make_kernel(
         "{[i,j]: 0<=i,j<n}",
         "result[i,j] = u[i, j]**2 + -1 + (-4)*u[i + 1, j + 1] \
@@ -1706,8 +1703,8 @@ def test_fd_demo(ctx_factory):
             ["i_inner", "j_inner"],
             fetch_bounding_box=True)
 
-    n = 1000
-    u = cl.clrandom.rand(queue, (n+2, n+2), dtype=np.float32)
+    #n = 1000
+    #u = cl.clrandom.rand(queue, (n+2, n+2), dtype=np.float32)
 
     knl = lp.set_options(knl, write_cl=True)
     knl = lp.add_and_infer_dtypes(knl, dict(u=np.float32))
@@ -1719,6 +1716,29 @@ def test_fd_demo(ctx_factory):
     assert "double" not in code
 
 
+def test_make_copy_kernel(ctx_factory):
+    """Check that copying into an intermediate layout and back to an
+    all-C layout reproduces the original array.
+    """
+    queue = cl.CommandQueue(ctx_factory())
+
+    intermediate_format = "f,f,sep"
+    original = np.random.randn(1024, 4, 3)
+
+    # all-C (default) -> intermediate layout
+    to_intermediate = lp.fix_parameters(
+            lp.make_copy_kernel(intermediate_format), n2=3)
+    to_intermediate = lp.set_options(to_intermediate, write_cl=True)
+    _, converted = to_intermediate(queue, input=original)
+
+    # intermediate layout -> all-C
+    to_c = lp.fix_parameters(
+            lp.make_copy_kernel("c,c,c", intermediate_format), n2=3)
+    _, restored = to_c(queue, input=converted)
+
+    assert (original == restored).all()
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab