diff --git a/doc/reference.rst b/doc/reference.rst index 09812654649664e4e7fa22ad2a9d9709cdf4b940..cffed51f75afe592175ea8dfb17227d1153dd9ab 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -423,8 +423,10 @@ Library interface .. autofunction:: register_function_manglers -Argument types -^^^^^^^^^^^^^^ +Arguments +^^^^^^^^^ + +.. autofunction:: set_argument_order .. autofunction:: add_dtypes diff --git a/loopy/__init__.py b/loopy/__init__.py index 4367374cff4f342a2a6689ed47781c1f8e5e880d..fc42b066df94dc9e13599a434a53e2c8a909a3c5 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -1441,4 +1441,37 @@ def make_copy_kernel(new_dim_tags, old_dim_tags=None): # }}} +# {{{ set argument order + +def set_argument_order(kernel, arg_names): + """ + :arg arg_names: A list (or comma-separated string) or argument + names. All arguments must be in this list. + """ + + if isinstance(arg_names, str): + arg_names = arg_names.split(",") + + new_args = [] + old_arg_dict = kernel.arg_dict.copy() + + for arg_name in arg_names: + try: + arg = old_arg_dict.pop(arg_name) + except KeyError: + raise LoopyError("unknown argument '%s'" + % arg_name) + + new_args.append(arg) + + if old_arg_dict: + raise LoopyError("incomplete argument list passed " + "to set_argument_order. Left over: '%s'" + % ", ".join(arg_name for arg_name in old_arg_dict)) + + return kernel.copy(args=new_args) + +# }}} + + # vim: foldmethod=marker diff --git a/test/test_loopy.py b/test/test_loopy.py index f8cbbc2ae7a8434ee604b6f989263ebcbddd8306..cb9b7f45fcf6111bee6983dc6432a07503dcb627 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1739,6 +1739,14 @@ def test_make_copy_kernel(ctx_factory): assert (a1 == a3).all() +def test_set_arg_order(): + knl = lp.make_kernel( + "{ [i,j]: 0<=i,j<n }", + "out[i,j] = a[i]*b[j]") + + knl = lp.set_argument_order(knl, "out,a,n,b") + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])