From 1af94f8417f51b71f634d9c17482ef464862b4d9 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 9 Nov 2011 10:47:51 -0500
Subject: [PATCH] An initial attempt at dealing with equality constraints.

---
 MEMO                      |  8 ++++----
 loopy/codegen/__init__.py | 10 ++++------
 loopy/codegen/loop.py     |  5 ++++-
 test/test_loopy.py        | 25 +++++++++++++++++++++++++
 4 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/MEMO b/MEMO
index d80e6519f..cd9014742 100644
--- a/MEMO
+++ b/MEMO
@@ -43,9 +43,6 @@ To-do
 
 - Fix all tests
 
-- Deal with equality constraints.
-  (These arise, e.g., when partitioning a loop of length 16 into 16s.)
-
 Future ideas
 ^^^^^^^^^^^^
 
@@ -86,7 +83,10 @@ Future ideas
 Dealt with
 ^^^^^^^^^^
 
-- dim_max caching
+- Deal with equality constraints.
+  (These arise, e.g., when partitioning a loop of length 16 into 16s.)
+
+- dim_{min,max} caching
 
 - Exhaust the search for a no-boost solution first, before looking
   for a schedule with boosts.
diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index a4dbafe64..62be45ce0 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -30,12 +30,7 @@ class GeneratedCode(Record):
     __slots__ = ["ast", "implemented_domains"]
 
 def gen_code_block(elements):
-    """
-    :param is_alternatives: a :class:`bool` indicating that
-        only one of the *elements* will effectively be executed.
-    """
-
-    from cgen import Block, Comment, Line
+    from cgen import Block, Comment, Line, Initializer
 
     block_els = []
     implemented_domains = {}
@@ -50,6 +45,9 @@ def gen_code_block(elements):
             else:
                 block_els.append(el.ast)
 
+        elif isinstance(el, Initializer):
+            block_els.append(el)
+
         elif isinstance(el, Comment):
             block_els.append(el)
 
diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py
index c1bc4b091..0beaf47db 100644
--- a/loopy/codegen/loop.py
+++ b/loopy/codegen/loop.py
@@ -16,7 +16,10 @@ def get_simple_loop_bounds(kernel, sched_index, iname, implemented_domain):
                     | frozenset(get_defined_inames(kernel, sched_index+1)),
                     allow_parameters=True)
 
-    assert not equality_constraints_orig
+    lower_constraints_orig.extend(equality_constraints_orig)
+    upper_constraints_orig.extend(equality_constraints_orig)
+    #assert not equality_constraints_orig
+
     from loopy.codegen.bounds import pick_simple_constraint
     lb_cns_orig = pick_simple_constraint(lower_constraints_orig, iname)
     ub_cns_orig = pick_simple_constraint(upper_constraints_orig, iname)
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 71e93909b..cc387b017 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -151,6 +151,31 @@ def test_stencil(ctx_factory):
 
 
 
+def test_eq_constraint(ctx_factory):
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0],
+            "{[i,j]: 0<= i,j < 32}",
+            [
+                "a[i] = b[i]"
+                ],
+            [
+                lp.ArrayArg("a", np.float32, shape=(1000,)),
+                lp.ArrayArg("b", np.float32, shape=(1000,))
+                ])
+
+    knl = lp.split_dimension(knl, "i", 16, outer_tag="g.0")
+    knl = lp.split_dimension(knl, "i_inner", 16, outer_tag=None, inner_tag="l.0")
+
+    kernel_gen = lp.generate_loop_schedules(knl)
+    kernel_gen = lp.check_kernels(kernel_gen)
+
+    for knl in kernel_gen:
+        print lp.generate_code(knl)
+
+
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
-- 
GitLab