From b5b6c88e07a8d6558fa092169a5b3fbfb6ba1f4b Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 7 May 2013 11:20:18 -0400
Subject: [PATCH] Support substitution rules in arg guessing.

---
 loopy/kernel/creation.py | 25 +++++++++++++++++--------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index 29c0a2f76..468f4528a 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -408,13 +408,16 @@ class IndexRankFinder(WalkMapper):
             else:
                 self.index_ranks.append(len(expr.index))
 
-def guess_kernel_args_if_requested(domains, instructions, temporary_variables, kernel_args):
+def guess_kernel_args_if_requested(domains, instructions, temporary_variables, subst_rules, kernel_args):
     if "..." not in kernel_args:
         return kernel_args
 
     kernel_args = kernel_args[:]
     kernel_args.remove("...")
 
+    from loopy.symbolic import SubstitutionRuleExpander
+    submap = SubstitutionRuleExpander(subst_rules)
+
     # {{{ find names that are *not* arguments
 
     all_inames = set()
@@ -442,8 +445,8 @@ def guess_kernel_args_if_requested(domains, instructions, temporary_variables, k
     from loopy.symbolic import get_dependencies
     for insn in instructions:
         all_written_names.add(insn.get_assignee_var_name())
-        all_names.update(get_dependencies(insn.expression))
-        all_names.update(get_dependencies(insn.assignee))
+        all_names.update(get_dependencies(submap(insn.assignee, insn.id)))
+        all_names.update(get_dependencies(submap(insn.expression, insn.id)))
 
     all_params = set()
     for dom in domains:
@@ -456,9 +459,10 @@ def guess_kernel_args_if_requested(domains, instructions, temporary_variables, k
 
     def find_index_rank(name):
         irf = IndexRankFinder(name)
+
         for insn in instructions:
-            irf(insn.expression)
-            irf(insn.assignee)
+            irf(submap(insn.expression, insn.id))
+            irf(submap(insn.assignee, insn.id))
 
         if not irf.index_ranks:
             return 0
@@ -778,6 +782,10 @@ def guess_arg_shape_if_requested(kernel, default_order):
     new_args = []
 
     from loopy.kernel.data import ShapedArg, auto_shape, auto_strides
+    from loopy.symbolic import SubstitutionRuleExpander
+
+    submap = SubstitutionRuleExpander(kernel.substitutions,
+            kernel.get_var_name_generator())
 
     for arg in kernel.args:
         if isinstance(arg, ShapedArg) and (
@@ -786,8 +794,9 @@ def guess_arg_shape_if_requested(kernel, default_order):
 
             for insn in kernel.instructions:
                 domain = kernel.get_inames_domain(kernel.insn_inames(insn))
-                armap(insn.assignee, domain)
-                armap(insn.expression, domain)
+                armap(submap(insn.assignee, insn.id), domain)
+                armap(submap(insn.expression, insn.id), domain)
+
 
             if armap.access_range is None:
                 # no subscripts found, let's call it a scalar
@@ -895,7 +904,7 @@ def make_kernel(device, domains, instructions, kernel_args=["..."], **kwargs):
     domains = parse_domains(isl_context, domains, defines)
 
     kernel_args = guess_kernel_args_if_requested(domains, instructions,
-            kwargs.get("temporary_variables", {}), kernel_args)
+            kwargs.get("temporary_variables", {}), substitutions, kernel_args)
 
     from loopy.kernel import LoopKernel
     knl = LoopKernel(device, domains, instructions, kernel_args, **kwargs)
-- 
GitLab