From 3729feb6c7578be126152aab8c09aed207140a16 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 25 Jun 2014 16:34:09 -0600 Subject: [PATCH] Add set_argument_order --- doc/reference.rst | 6 ++++-- loopy/__init__.py | 33 +++++++++++++++++++++++++++++++++ test/test_loopy.py | 8 ++++++++ 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/doc/reference.rst b/doc/reference.rst index 098126546..cffed51f7 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -423,8 +423,10 @@ Library interface .. autofunction:: register_function_manglers -Argument types -^^^^^^^^^^^^^^ +Arguments +^^^^^^^^^ + +.. autofunction:: set_argument_order .. autofunction:: add_dtypes diff --git a/loopy/__init__.py b/loopy/__init__.py index 4367374cf..fc42b066d 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -1441,4 +1441,37 @@ def make_copy_kernel(new_dim_tags, old_dim_tags=None): # }}} +# {{{ set argument order + +def set_argument_order(kernel, arg_names): + """ + :arg arg_names: A list (or comma-separated string) or argument + names. All arguments must be in this list. + """ + + if isinstance(arg_names, str): + arg_names = arg_names.split(",") + + new_args = [] + old_arg_dict = kernel.arg_dict.copy() + + for arg_name in arg_names: + try: + arg = old_arg_dict.pop(arg_name) + except KeyError: + raise LoopyError("unknown argument '%s'" + % arg_name) + + new_args.append(arg) + + if old_arg_dict: + raise LoopyError("incomplete argument list passed " + "to set_argument_order. Left over: '%s'" + % ", ".join(arg_name for arg_name in old_arg_dict)) + + return kernel.copy(args=new_args) + +# }}} + + # vim: foldmethod=marker diff --git a/test/test_loopy.py b/test/test_loopy.py index f8cbbc2ae..cb9b7f45f 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1739,6 +1739,14 @@ def test_make_copy_kernel(ctx_factory): assert (a1 == a3).all() +def test_set_arg_order(): + knl = lp.make_kernel( + "{ [i,j]: 0<=i,j<n }", + "out[i,j] = a[i]*b[j]") + + knl = lp.set_argument_order(knl, "out,a,n,b") + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab