From e20729c5c2a2d9437dfb8c23d74d4188d7ca0f7f Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 12 Jun 2014 08:56:51 +0100
Subject: [PATCH] Implement loopy.assume

---
 doc/reference.rst        |  4 ++--
 loopy/__init__.py        | 22 ++++++++++++++++++++++
 loopy/kernel/__init__.py | 22 +++++++++++++---------
 test/test_loopy.py       | 22 ++++++++++++++++++++++
 4 files changed, 59 insertions(+), 11 deletions(-)

diff --git a/doc/reference.rst b/doc/reference.rst
index f7347c118..4892b1ad7 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -362,6 +362,8 @@ Dealing with Parameters
 
 .. autofunction:: fix_parameters
 
+.. autofunction:: assume
+
 Dealing with Substitution Rules
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -376,8 +378,6 @@ Caching, Precomputation and Prefetching
 
 .. autofunction:: add_prefetch
 
-    Uses :func:`extract_subst` and :func:`precompute`.
-
 Influencing data access
 ^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index d16d01597..5d9ad93e7 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -908,6 +908,8 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
         directly by putting an index expression into *var_name*. Substitutions
         such as those occurring in dimension splits are recorded and also
         applied to these indices.
+
+    This function combines :func:`extract_subst` and :func:`precompute`.
     """
 
     # {{{ fish indexing out of var_name and into footprint_subscripts
@@ -1259,6 +1261,26 @@ def fix_parameters(kernel, **value_dict):
 # }}}
 
 
+# {{{ assume
+
+def assume(kernel, assumptions):
+    if isinstance(assumptions, str):
+        assumptions_set_str = "[%s] -> { : %s}" \
+                % (",".join(s for s in kernel.outer_params()),
+                    assumptions)
+        assumptions = isl.BasicSet.read_from_str(kernel.domains[0].get_ctx(),
+                assumptions_set_str)
+
+    if not isinstance(assumptions, isl.BasicSet):
+        raise TypeError("'assumptions' must be a BasicSet or a string")
+
+    old_assumptions, new_assumptions = isl.align_two(kernel.assumptions, assumptions)
+
+    return kernel.copy(assumptions=old_assumptions.params() & new_assumptions.params())
+
+# }}}
+
+
 # {{{ set_options
 
 def set_options(kernel, *args, **kwargs):
diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py
index 4fa3e05ca..0c9e4d3d1 100644
--- a/loopy/kernel/__init__.py
+++ b/loopy/kernel/__init__.py
@@ -225,16 +225,8 @@ class LoopKernel(RecordWithoutPickling):
             assumptions = isl.BasicSet.universe(assumptions_space)
 
         elif isinstance(assumptions, str):
-            all_inames = set()
-            all_params = set()
-            for dom in domains:
-                all_inames.update(dom.get_var_names(dim_type.set))
-                all_params.update(dom.get_var_names(dim_type.param))
-
-            domain_parameters = all_params-all_inames
-
             assumptions_set_str = "[%s] -> { : %s}" \
-                    % (",".join(s for s in domain_parameters),
+                    % (",".join(s for s in self.outer_params(domains)),
                         assumptions)
             assumptions = isl.BasicSet.read_from_str(domains[0].get_ctx(),
                     assumptions_set_str)
@@ -591,6 +583,18 @@ class LoopKernel(RecordWithoutPickling):
 
         return frozenset(result)
 
+    def outer_params(self, domains=None):
+        if domains is None:
+            domains = self.domains
+
+        all_inames = set()
+        all_params = set()
+        for dom in self.domains:
+            all_inames.update(dom.get_var_names(dim_type.set))
+            all_params.update(dom.get_var_names(dim_type.param))
+
+        return all_params-all_inames
+
     @memoize_method
     def all_insn_inames(self):
         """Return a mapping from instruction ids to inames inside which
diff --git a/test/test_loopy.py b/test/test_loopy.py
index f9e4eee78..d17ec1c0c 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -127,6 +127,28 @@ def test_sized_and_complex_literals(ctx_factory):
     lp.auto_test_vs_ref(knl, ctx, knl, parameters=dict(n=5))
 
 
+def test_assume(ctx_factory):
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(
+            "{[i]: 0<=i<n}",
+            "a[i] = a[i] + 1",
+            [lp.GlobalArg("a", np.float32, shape="n"), "..."])
+
+    knl = lp.split_iname(knl, "i", 16)
+    knl = lp.set_loop_priority(knl, "i_outer,i_inner")
+    knl = lp.assume(knl, "n mod 16 = 0")
+    knl = lp.assume(knl, "n > 10")
+    knl = lp.preprocess_kernel(knl, ctx.devices[0])
+    kernel_gen = lp.generate_loop_schedules(knl)
+
+    for gen_knl in kernel_gen:
+        print gen_knl
+        compiled = lp.CompiledKernel(ctx, gen_knl)
+        print compiled.get_code()
+        assert "if" not in compiled.get_code()
+
+
 def test_simple_side_effect(ctx_factory):
     ctx = ctx_factory()
 
-- 
GitLab