From 53995111eb1dd3274966d1c41382e3fc72f8ebdc Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 18 Dec 2013 18:35:06 -0800
Subject: [PATCH] Allow string-o-names to specify argument order

---
 loopy/kernel/creation.py | 186 +++++++++++++++++++++++----------------
 1 file changed, 109 insertions(+), 77 deletions(-)

diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index f68bc5042..824a8b8f4 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -401,104 +401,130 @@ class IndexRankFinder(WalkMapper):
                 self.index_ranks.append(len(expr.index))
 
 
-def guess_kernel_args_if_requested(domains, instructions, temporary_variables,
-        subst_rules, kernel_args, default_offset):
-    # Ellipsis is syntactically allowed in Py3.
-    if "..." not in kernel_args and Ellipsis not in kernel_args:
-        return kernel_args
+class ArgumentGuesser:
+    def __init__(self, domains, instructions, temporary_variables,
+            subst_rules, default_offset):
+        self.domains = domains
+        self.instructions = instructions
+        self.temporary_variables = temporary_variables
+        self.subst_rules = subst_rules
+        self.default_offset = default_offset
+
+        from loopy.symbolic import SubstitutionRuleExpander
+        self.submap = SubstitutionRuleExpander(subst_rules)
+
+        self.all_inames = set()
+        for dom in domains:
+            self.all_inames.update(dom.get_var_names(dim_type.set))
+
+        all_params = set()
+        for dom in domains:
+            all_params.update(dom.get_var_names(dim_type.param))
+        self.all_params = all_params - self.all_inames
+
+        self.all_names = set()
+        self.all_written_names = set()
+        from loopy.symbolic import get_dependencies
+        for insn in instructions:
+            if isinstance(insn, ExpressionInstruction):
+                (assignee_var_name, _), = insn.assignees_and_indices()
+                self.all_written_names.add(assignee_var_name)
+                self.all_names.update(get_dependencies(
+                    self.submap(insn.assignee, insn.id)))
+                self.all_names.update(get_dependencies(
+                    self.submap(insn.expression, insn.id)))
 
-    kernel_args = [arg for arg in kernel_args
-            if arg is not Ellipsis and arg != "..."]
+    def find_index_rank(self, name):
+        irf = IndexRankFinder(name)
 
-    from loopy.symbolic import SubstitutionRuleExpander
-    submap = SubstitutionRuleExpander(subst_rules)
+        for insn in self.instructions:
+            insn.with_transformed_expressions(
+                    lambda expr: irf(self.submap(expr, insn.id)))
 
-    # {{{ find names that are *not* arguments
+        if not irf.index_ranks:
+            return 0
+        else:
+            from pytools import single_valued
+            return single_valued(irf.index_ranks)
 
-    all_inames = set()
-    for dom in domains:
-        all_inames.update(dom.get_var_names(dim_type.set))
+    def make_new_arg(self, arg_name):
+        arg_name = arg_name.strip()
 
-    temp_var_names = set(temporary_variables.iterkeys())
+        from loopy.kernel.data import ValueArg, GlobalArg
+        import loopy as lp
 
-    for insn in instructions:
-        if isinstance(insn, ExpressionInstruction):
-            if insn.temp_var_type is not None:
-                (assignee_var_name, _), = insn.assignees_and_indices()
-                temp_var_names.add(assignee_var_name)
+        if arg_name in self.all_params:
+            return ValueArg(arg_name)
 
-    # }}}
+        if arg_name in self.all_written_names:
+            # It's not a temp var, and thereby not a domain parameter--the only
+            # other writable type of variable is an argument.
 
-    # {{{ find existing and new arg names
+            return GlobalArg(arg_name,
+                    shape=lp.auto, offset=self.default_offset)
 
-    existing_arg_names = set()
-    for arg in kernel_args:
-        existing_arg_names.add(arg.name)
+        irank = self.find_index_rank(arg_name)
+        if irank == 0:
+            # read-only, no indices
+            return ValueArg(arg_name)
+        else:
+            return GlobalArg(arg_name, shape=lp.auto, offset=self.default_offset)
 
-    not_new_arg_names = existing_arg_names | temp_var_names | all_inames
+    def convert_names_to_full_args(self, kernel_args):
+        new_kernel_args = []
 
-    all_names = set()
-    all_written_names = set()
-    from loopy.symbolic import get_dependencies
-    for insn in instructions:
-        if isinstance(insn, ExpressionInstruction):
-            (assignee_var_name, _), = insn.assignees_and_indices()
-            all_written_names.add(assignee_var_name)
-            all_names.update(get_dependencies(submap(insn.assignee, insn.id)))
-            all_names.update(get_dependencies(submap(insn.expression, insn.id)))
+        for arg in kernel_args:
+            if isinstance(arg, str) and arg != "...":  # keep ellipsis marker
+                new_kernel_args.append(self.make_new_arg(arg))
+            else:
+                new_kernel_args.append(arg)
 
-    from loopy.kernel.data import ArrayBase
-    for arg in kernel_args:
-        if isinstance(arg, ArrayBase):
-            if isinstance(arg.shape, tuple):
-                all_names.update(get_dependencies(arg.shape))
+        return new_kernel_args
 
-    all_params = set()
-    for dom in domains:
-        all_params.update(dom.get_var_names(dim_type.param))
-    all_params = all_params - all_inames
+    def guess_kernel_args_if_requested(self, kernel_args):
+        # Ellipsis is syntactically allowed in Py3.
+        if "..." not in kernel_args and Ellipsis not in kernel_args:
+            return kernel_args
 
-    new_arg_names = (all_names | all_params) - not_new_arg_names
+        kernel_args = [arg for arg in kernel_args
+                if arg is not Ellipsis and arg != "..."]
 
-    # }}}
+        # {{{ find names that are *not* arguments
 
-    def find_index_rank(name):
-        irf = IndexRankFinder(name)
+        temp_var_names = set(self.temporary_variables.iterkeys())
 
-        for insn in instructions:
-            insn.with_transformed_expressions(
-                    lambda expr: irf(submap(expr, insn.id)))
+        for insn in self.instructions:
+            if isinstance(insn, ExpressionInstruction):
+                if insn.temp_var_type is not None:
+                    (assignee_var_name, _), = insn.assignees_and_indices()
+                    temp_var_names.add(assignee_var_name)
 
-        if not irf.index_ranks:
-            return 0
-        else:
-            from pytools import single_valued
-            return single_valued(irf.index_ranks)
+        # }}}
 
-    from loopy.kernel.data import ValueArg, GlobalArg
-    import loopy as lp
-    for arg_name in sorted(new_arg_names):
-        if arg_name in all_params:
-            kernel_args.append(ValueArg(arg_name))
-            continue
+        # {{{ find existing and new arg names
 
-        if arg_name in all_written_names:
-            # It's not a temp var, and thereby not a domain parameter--the only
-            # other writable type of variable is an argument.
+        existing_arg_names = set()
+        for arg in kernel_args:
+            existing_arg_names.add(arg.name)
 
-            kernel_args.append(
-                    GlobalArg(arg_name, shape=lp.auto, offset=default_offset))
-            continue
+        not_new_arg_names = existing_arg_names | temp_var_names | self.all_inames
 
-        irank = find_index_rank(arg_name)
-        if irank == 0:
-            # read-only, no indices
-            kernel_args.append(ValueArg(arg_name))
-        else:
-            kernel_args.append(
-                    GlobalArg(arg_name, shape=lp.auto, offset=default_offset))
+        from loopy.kernel.data import ArrayBase
+        from loopy.symbolic import get_dependencies
+        for arg in kernel_args:
+            if isinstance(arg, ArrayBase):
+                if isinstance(arg.shape, tuple):
+                    self.all_names.update(
+                            get_dependencies(arg.shape))
 
-    return kernel_args
+        new_arg_names = (self.all_names | self.all_params) - not_new_arg_names
+
+        # }}}
+
+        for arg_name in sorted(new_arg_names):
+            kernel_args.append(self.make_new_arg(arg_name))
+
+        return kernel_args
 
 # }}}
 
@@ -978,6 +1004,9 @@ def make_kernel(device, domains, instructions, kernel_data=["..."], **kwargs):
 
     from loopy.kernel.data import TemporaryVariable, ArrayBase
 
+    if isinstance(kernel_data, str):
+        kernel_data = kernel_data.split(",")
+
     kernel_args = []
     temporary_variables = {}
     for dat in kernel_data:
@@ -1038,10 +1067,13 @@ def make_kernel(device, domains, instructions, kernel_data=["..."], **kwargs):
 
     domains = parse_domains(isl_context, domains, defines)
 
-    kernel_args = guess_kernel_args_if_requested(domains, instructions,
-            temporary_variables, substitutions, kernel_args,
+    arg_guesser = ArgumentGuesser(domains, instructions,
+            temporary_variables, substitutions,
             default_offset)
 
+    kernel_args = arg_guesser.convert_names_to_full_args(kernel_args)
+    kernel_args = arg_guesser.guess_kernel_args_if_requested(kernel_args)
+
     from loopy.kernel import LoopKernel
     knl = LoopKernel(device, domains, instructions, kernel_args,
             temporary_variables=temporary_variables,
-- 
GitLab