From 69ffdedfd09c800f3f47a060ec946983bbd6301e Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 22 Sep 2012 22:26:26 -0400
Subject: [PATCH] FIX: Slab decomposition should not influence grid sizes.

---
 MEMO                  |  2 +-
 loopy/codegen/loop.py |  4 +++-
 loopy/kernel.py       | 11 ++++++++++-
 3 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/MEMO b/MEMO
index dee9d32f8..bc1e1d6d8 100644
--- a/MEMO
+++ b/MEMO
@@ -107,7 +107,7 @@ Future ideas
 Dealt with
 ^^^^^^^^^^
 
-- test divisibility constraints
+- Test divisibility constraints
 
 - Test join_inames
 
diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py
index bb774492c..9c5159111 100644
--- a/loopy/codegen/loop.py
+++ b/loopy/codegen/loop.py
@@ -123,7 +123,9 @@ def intersect_kernel_with_slab(kernel, slab, iname):
     home_domain = kernel.domains[hdi]
     new_domains = kernel.domains[:]
     new_domains[hdi] = home_domain & isl.align_spaces(slab, home_domain)
-    return kernel.copy(domains=new_domains)
+
+    return kernel.copy(domains=new_domains,
+            get_grid_sizes=kernel.get_grid_sizes)
 
 
 # {{{ hw-parallel loop
diff --git a/loopy/kernel.py b/loopy/kernel.py
index 01ba6e236..15911707e 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -724,7 +724,12 @@ class LoopKernel(Record):
             cache_manager=None,
             iname_to_tag_requests=None,
             index_dtype=np.int32,
-            isl_context=None):
+            isl_context=None,
+
+            # When kernels get intersected in slab decomposition,
+            # their grid sizes shouldn't change. If a callable is passed
+            # here, the sub-kernel's grid size requests are forwarded to it.
+            get_grid_sizes=None):
         """
         :arg domain: a :class:`islpy.BasicSet`, or a string parseable to a basic set by the isl.
             Example: "{[i,j]: 0<=i < 10 and 0<= j < 9}"
@@ -995,6 +1000,10 @@ class LoopKernel(Record):
         if np.iinfo(index_dtype).min >= 0:
             raise TypeError("index_dtype must be signed")
 
+        if get_grid_sizes is not None:
+            # instance attribute shadows the get_grid_sizes method below
+            self.get_grid_sizes = get_grid_sizes
+
         Record.__init__(self,
                 device=device, domains=domains,
                 instructions=parsed_instructions,
-- 
GitLab