From d0f46221e2249d7894aed2d5e7ab21e84c419eac Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 22 Sep 2012 23:35:02 -0400
Subject: [PATCH] Retain all necessary parameters in determining whether
 accesses are in-footprint.

Also catch branch-in-domain-tree if necessary.
---
 loopy/cse.py       | 13 +++++++++++--
 loopy/kernel.py    |  4 +++-
 test/test_nbody.py |  7 +++----
 3 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/loopy/cse.py b/loopy/cse.py
index 809748b0b..c468c7f9d 100644
--- a/loopy/cse.py
+++ b/loopy/cse.py
@@ -160,13 +160,22 @@ def build_global_storage_to_sweep_map(kernel, invocation_descriptors,
 
     for invdesc in invocation_descriptors:
         if not invdesc.expands_footprint:
-            arg_inames = set()
+            arg_inames = (
+                    set(global_s2s_par_dom.get_var_names(dim_type.param))
+                    & kernel.all_inames())
 
             for arg in invdesc.args:
                 arg_inames.update(get_dependencies(arg))
             arg_inames = frozenset(arg_inames)
 
-            usage_domain = kernel.get_inames_domain(arg_inames)
+            from loopy.kernel import CannotBranchDomainTree
+            try:
+                usage_domain = kernel.get_inames_domain(arg_inames)
+            except CannotBranchDomainTree:
+                # and that's the end of that.
+                invdesc.is_in_footprint = False
+                continue
+
             for i in xrange(usage_domain.dim(dim_type.set)):
                 iname = usage_domain.get_dim_name(dim_type.set, i)
                 if iname in sweep_inames:
diff --git a/loopy/kernel.py b/loopy/kernel.py
index 15911707e..45f9245a9 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -12,6 +12,8 @@ import re
 
 
 
+class CannotBranchDomainTree(RuntimeError):
+    pass
 
 # {{{ index tags
 
@@ -1278,7 +1280,7 @@ class LoopKernel(Record):
 
             all_parents = set(ppd[home_domain_index])
             if not domain_indices <= all_parents:
-                raise RuntimeError("iname set '%s' requires "
+                raise CannotBranchDomainTree("iname set '%s' requires "
                         "branch in domain tree (when adding '%s')"
                         % (", ".join(inames), iname))
 
diff --git a/test/test_nbody.py b/test/test_nbody.py
index 931b460a5..da7c77b01 100644
--- a/test/test_nbody.py
+++ b/test/test_nbody.py
@@ -48,16 +48,15 @@ def test_nbody(ctx_factory):
         knl = lp.split_iname(knl, "i", 256,
                 outer_tag="g.0", inner_tag="l.0", slabs=(0,1))
         knl = lp.split_iname(knl, "j", 256, slabs=(0,1))
-        knl = lp.add_prefetch(knl, "x[i,k]", ["k"], default_tag=None)
         knl = lp.add_prefetch(knl, "x[j,k]", ["j_inner", "k"],
                 ["x_fetch_j", "x_fetch_k"])
+        knl = lp.add_prefetch(knl, "x[i,k]", ["k"], default_tag=None)
         knl = lp.tag_inames(knl, dict(x_fetch_k="unr"))
         return knl, ["j_outer", "j_inner"]
 
     n = 3000
 
-    for variant in [ variant_cpu]:
-    #for variant in [variant_1, variant_cpu, variant_gpu]:
+    for variant in [variant_1, variant_cpu, variant_gpu]:
         variant_knl, loop_prio = variant(knl)
         kernel_gen = lp.generate_loop_schedules(variant_knl,
                 loop_priority=loop_prio)
@@ -65,7 +64,7 @@ def test_nbody(ctx_factory):
 
         lp.auto_test_vs_ref(seq_knl, ctx, kernel_gen,
                 op_count=[n**2*1e-6], op_label=["M particle pairs"],
-                parameters={"N": n}, print_ref_code=True)
+                parameters={"N": n})
 
 
 
-- 
GitLab