From 3729feb6c7578be126152aab8c09aed207140a16 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 25 Jun 2014 16:34:09 -0600
Subject: [PATCH] Add set_argument_order

---
 doc/reference.rst  |  6 ++++--
 loopy/__init__.py  | 33 +++++++++++++++++++++++++++++++++
 test/test_loopy.py |  8 ++++++++
 3 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/doc/reference.rst b/doc/reference.rst
index 098126546..cffed51f7 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -423,8 +423,10 @@ Library interface
 
 .. autofunction:: register_function_manglers
 
-Argument types
-^^^^^^^^^^^^^^
+Arguments
+^^^^^^^^^
+
+.. autofunction:: set_argument_order
 
 .. autofunction:: add_dtypes
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 4367374cf..fc42b066d 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -1441,4 +1441,37 @@ def make_copy_kernel(new_dim_tags, old_dim_tags=None):
 # }}}
 
 
+# {{{ set argument order
+
+def set_argument_order(kernel, arg_names):
+    """
+    :arg arg_names: A list (or comma-separated string) or argument
+        names. All arguments must be in this list.
+    """
+
+    if isinstance(arg_names, str):
+        arg_names = arg_names.split(",")
+
+    new_args = []
+    old_arg_dict = kernel.arg_dict.copy()
+
+    for arg_name in arg_names:
+        try:
+            arg = old_arg_dict.pop(arg_name)
+        except KeyError:
+            raise LoopyError("unknown argument '%s'"
+                    % arg_name)
+
+        new_args.append(arg)
+
+    if old_arg_dict:
+        raise LoopyError("incomplete argument list passed "
+                "to set_argument_order. Left over: '%s'"
+                % ", ".join(arg_name for arg_name in old_arg_dict))
+
+    return kernel.copy(args=new_args)
+
+# }}}
+
+
 # vim: foldmethod=marker
diff --git a/test/test_loopy.py b/test/test_loopy.py
index f8cbbc2ae..cb9b7f45f 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -1739,6 +1739,14 @@ def test_make_copy_kernel(ctx_factory):
     assert (a1 == a3).all()
 
 
+def test_set_arg_order():
+    knl = lp.make_kernel(
+            "{ [i,j]: 0<=i,j<n }",
+            "out[i,j] = a[i]*b[j]")
+
+    knl = lp.set_argument_order(knl, "out,a,n,b")
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab