From 44305ffae2d95fdb3fd1e118f71961a95e94c0a3 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 24 Sep 2011 19:37:18 -0400
Subject: [PATCH] Many fixes for late reduction realization.

---
 MEMO                |  8 ++--
 loopy/__init__.py   | 90 ++++++++++++++++++++++++---------------------
 loopy/kernel.py     | 23 ------------
 loopy/schedule.py   | 14 ++++---
 loopy/symbolic.py   |  7 +++-
 test/test_matmul.py | 13 ++-----
 6 files changed, 70 insertions(+), 85 deletions(-)

diff --git a/MEMO b/MEMO
index 22fa41ca8..942ac822c 100644
--- a/MEMO
+++ b/MEMO
@@ -55,13 +55,15 @@ Things to consider
 
 - Implement get_problems()
 
-- CSE iname duplication might be unnecessary?
-  (don't think so: It might be desired to do a full fetch before a mxm k loop
-  even if that requires going iterative.)
+- FIXME: Deal with insns losing a seq iname dep in a CSE realization
 
 Dealt with
 ^^^^^^^^^^
 
+- CSE iname duplication might be unnecessary?
+  (don't think so: It might be desired to do a full fetch before a mxm k loop
+  even if that requires going iterative.)
+
 - Reduction needs to know a neutral element
 
 - Types of reduction variables?
diff --git a/loopy/__init__.py b/loopy/__init__.py
index d9cedd753..6c7136304 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -37,13 +37,13 @@ from loopy.compiled import CompiledKernel, drive_timing_run
 
 # {{{ user-facing kernel manipulation functionality
 
-def split_dimension(kernel, name, inner_length, padded_length=None,
-        outer_name=None, inner_name=None,
+def split_dimension(kernel, iname, inner_length, padded_length=None,
+        outer_iname=None, inner_iname=None,
         outer_tag=None, inner_tag=None,
         outer_slab_increments=(0, -1), no_slabs=None):
 
-    if name not in kernel.all_inames():
-        raise ValueError("cannot split loop for unknown variable '%s'" % name)
+    if iname not in kernel.all_inames():
+        raise ValueError("cannot split loop for unknown variable '%s'" % iname)
 
     if no_slabs:
         outer_slab_increments = (0, 0)
@@ -51,29 +51,29 @@ def split_dimension(kernel, name, inner_length, padded_length=None,
     if padded_length is not None:
         inner_tag = inner_tag.copy(forced_length=padded_length)
 
-    if outer_name is None:
-        outer_name = name+"_outer"
-    if inner_name is None:
-        inner_name = name+"_inner"
+    if outer_iname is None:
+        outer_iname = iname+"_outer"
+    if inner_iname is None:
+        inner_iname = iname+"_inner"
 
     outer_var_nr = kernel.space.dim(dim_type.set)
     inner_var_nr = kernel.space.dim(dim_type.set)+1
 
     def process_set(s):
         s = s.add_dims(dim_type.set, 2)
-        s.set_dim_name(dim_type.set, outer_var_nr, outer_name)
-        s.set_dim_name(dim_type.set, inner_var_nr, inner_name)
+        s.set_dim_name(dim_type.set, outer_var_nr, outer_iname)
+        s.set_dim_name(dim_type.set, inner_var_nr, inner_iname)
 
         from loopy.isl import make_slab
 
         space = s.get_space()
         inner_constraint_set = (
-                make_slab(space, inner_name, 0, inner_length)
+                make_slab(space, inner_iname, 0, inner_length)
                 # name = inner + length*outer
                 .add_constraint(isl.Constraint.eq_from_names(
-                    space, {name:1, inner_name: -1, outer_name:-inner_length})))
+                    space, {iname:1, inner_iname: -1, outer_iname:-inner_length})))
 
-        name_dim_type, name_idx = space.get_var_dict()[name]
+        name_dim_type, name_idx = space.get_var_dict()[iname]
         return (s
                 .intersect(inner_constraint_set)
                 .project_out(name_dim_type, name_idx, 1))
@@ -82,37 +82,46 @@ def split_dimension(kernel, name, inner_length, padded_length=None,
     new_assumptions = process_set(kernel.assumptions)
 
     from pymbolic import var
-    inner = var(inner_name)
-    outer = var(outer_name)
+    inner = var(inner_iname)
+    outer = var(outer_iname)
     new_loop_index = inner + outer*inner_length
 
-    # {{{ look for reduction loops to split, split seq. tags
+    # {{{ actually modify instructions
 
     from loopy.symbolic import ReductionLoopSplitter
 
-    rls = ReductionLoopSplitter(name, outer_name, inner_name)
+    rls = ReductionLoopSplitter(iname, outer_iname, inner_iname)
     new_insns = []
     for insn in kernel.instructions:
-        insn = insn.copy(expression=rls(insn.expression))
-        old_iname_tag = insn.iname_to_tag.get(name)
+        subst_map = {var(iname): new_loop_index}
 
+        from loopy.symbolic import SubstitutionMapper
+        subst_mapper = SubstitutionMapper(subst_map.get)
+
+        new_expr = subst_mapper(rls(insn.expression))
+
+        old_iname_tag = insn.iname_to_tag.get(iname)
         new_iname_to_tag = insn.iname_to_tag.copy()
-        tagged_ok = False
-        from loopy.kernel import SequentialTag
-        if isinstance(old_iname_tag, SequentialTag):
-            tagged_ok = True
-            del new_iname_to_tag[name]
-            new_iname_to_tag[outer_name] = old_iname_tag
-            new_iname_to_tag[inner_name] = old_iname_tag
-
-        if name in insn.forced_iname_deps:
+
+        from loopy.kernel import UniqueTag
+        if not isinstance(old_iname_tag, UniqueTag):
+            new_iname_to_tag.pop(iname, None)
+            new_iname_to_tag[outer_iname] = old_iname_tag
+            new_iname_to_tag[inner_iname] = old_iname_tag
+        else:
+            raise RuntimeError("cannot split already unique-tagged iname '%s'"
+                    % iname)
+
+        if iname in insn.forced_iname_deps:
             new_forced_iname_deps = insn.forced_iname_deps[:]
-            new_forced_iname_deps.remove(name)
-            new_forced_iname_deps.extend([outer_name, inner_name])
+            new_forced_iname_deps.remove(iname)
+            new_forced_iname_deps.extend([outer_iname, inner_iname])
         else:
             new_forced_iname_deps = insn.forced_iname_deps
 
-        insn = insn.substitute(name, new_loop_index, tagged_ok=tagged_ok).copy(
+        insn = insn.copy(
+                assignee=subst_mapper(insn.assignee),
+                expression=new_expr,
                 iname_to_tag=new_iname_to_tag,
                 forced_iname_deps=new_forced_iname_deps
                 )
@@ -122,7 +131,7 @@ def split_dimension(kernel, name, inner_length, padded_length=None,
     # }}}
 
     iname_slab_increments = kernel.iname_slab_increments.copy()
-    iname_slab_increments[outer_name] = outer_slab_increments
+    iname_slab_increments[outer_iname] = outer_slab_increments
     result = (kernel
             .copy(domain=new_domain,
                 assumptions=new_assumptions,
@@ -130,7 +139,7 @@ def split_dimension(kernel, name, inner_length, padded_length=None,
                 name_to_dim=None,
                 instructions=new_insns))
 
-    return tag_dimensions(result, {outer_name: outer_tag, inner_name: inner_tag})
+    return tag_dimensions(result, {outer_iname: outer_tag, inner_iname: inner_tag})
 
 
 
@@ -311,13 +320,7 @@ def realize_cse(kernel, cse_tag, dtype, duplicate_inames=[], parallel_inames=Non
         new_expr = cse_cb_mapper(insn.expression)
 
         if was_empty and cse_result_insns:
-            new_iname_to_tag = insn.iname_to_tag.copy()
-            new_iname_to_tag.update(dup_iname_to_tag)
-
-            new_insns.append(
-                    insn.copy(
-                        expression=new_expr,
-                        iname_to_tag=new_iname_to_tag))
+            new_insns.append(insn.copy(expression=new_expr))
         else:
             new_insns.append(insn)
 
@@ -401,7 +404,7 @@ def realize_cse(kernel, cse_tag, dtype, duplicate_inames=[], parallel_inames=Non
 
 
 
-def realize_reduction(kernel, loop_iname, reduction_tag=None):
+def realize_reduction(kernel, inames=None, reduction_tag=None):
     new_insns = []
     new_temporary_variables = kernel.temporary_variables[:]
 
@@ -411,9 +414,12 @@ def realize_reduction(kernel, loop_iname, reduction_tag=None):
         if reduction_tag is not None and expr.tag != reduction_tag:
             return
 
+        if inames is not None and set(inames) != set(expr.inames):
+            return
+
         from pymbolic import var
 
-        target_var_name = kernel.make_unique_var_name("red_"+loop_iname,
+        target_var_name = kernel.make_unique_var_name("red",
                 extra_used_vars=set(tv.name for tv in new_temporary_variables))
         target_var = var(target_var_name)
 
diff --git a/loopy/kernel.py b/loopy/kernel.py
index 5df71831a..c3a0ca0ec 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -284,21 +284,6 @@ class Instruction(Record):
 
         return result
 
-    def substitute(self, old_var, new_expr, tagged_ok=False):
-        from loopy.symbolic import SubstitutionMapper
-
-        prev_tag = self.iname_to_tag.get(old_var)
-        if prev_tag is not None and not tagged_ok:
-            raise RuntimeError("cannot substitute already tagged variable '%s'"
-                    % old_var)
-
-        subst_map = {var(old_var): new_expr}
-
-        subst_mapper = SubstitutionMapper(subst_map.get)
-        return self.copy(
-                assignee=subst_mapper(self.assignee),
-                expression=subst_mapper(self.expression))
-
     def __str__(self):
         loop_descrs = []
         for iname in sorted(self.all_inames()):
@@ -633,14 +618,6 @@ class LoopKernel(Record):
         return sum(lv.nbytes for lv in self.temporary_variables
                 if lv.is_local)
 
-    def substitute(self, old_var, new_expr):
-        if self.schedule is not None:
-            raise RuntimeError("cannot substitute-schedule already generated")
-
-        return self.copy(instructions=[
-            insn.substitute(old_var, new_expr)
-            for insn in self.instructions])
-
 # }}}
 
 
diff --git a/loopy/schedule.py b/loopy/schedule.py
index 08729bbe5..7a3cc5239 100644
--- a/loopy/schedule.py
+++ b/loopy/schedule.py
@@ -77,22 +77,24 @@ def generate_loop_schedules_internal(kernel, entered_loops=[]):
 
 
 def generate_loop_schedules(kernel):
-    # {{{ check that all CSEs and reductions are realized
+    from loopy import realize_reduction
+    kernel = realize_reduction(kernel)
 
-    from loopy.symbolic import CSECallbackMapper, ReductionCallbackMapper
+    # {{{ check that all CSEs are realized
 
-    def map_reduction(expr, rec):
-        raise RuntimeError("all reductions must be realized before scheduling")
+    from loopy.symbolic import CSECallbackMapper
 
     def map_cse(expr, rec):
         raise RuntimeError("all CSEs must be realized before scheduling")
 
     for insn in kernel.instructions:
-        ReductionCallbackMapper(map_reduction)(insn.expression)
         CSECallbackMapper(map_cse)(insn.expression)
 
     # }}}
 
+    for i, insn_a in enumerate(kernel.instructions):
+        print i, insn_a
+
     kernel = fix_grid_sizes(kernel)
 
     if 0:
@@ -101,6 +103,8 @@ def generate_loop_schedules(kernel):
             print "%s: %s" % (k, ",".join(v))
         1/0
 
+    kernel = find_automatic_dependencies(kernel)
+
     #grid_size, group_size = find_known_grid_and_group_sizes(kernel)
 
     #kernel = assign_grid_and_group_indices(kernel)
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 816943f3a..f44da7b03 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -148,10 +148,10 @@ class ReductionLoopSplitter(IdentityMapper):
 
     def map_reduction(self, expr):
         if self.old_iname in expr.inames:
-            new_inames = expr.inames[:]
+            new_inames = list(expr.inames)
             new_inames.remove(self.old_iname)
             new_inames.extend([self.outer_iname, self.inner_iname])
-            return Reduction(expr.operation, new_inames,
+            return Reduction(expr.operation, tuple(new_inames),
                         expr.expr, expr.tag)
         else:
             return IdentityMapper.map_reduction(self, expr)
@@ -490,6 +490,9 @@ class IndexVariableFinder(CombineMapper):
                 raise RuntimeError("index variable not understood: %s" % idx_var)
         return result
 
+    def map_reduction(self, expr):
+        return set(expr.inames) | self.rec(expr.expr)
+
 # }}}
 
 
diff --git a/test/test_matmul.py b/test/test_matmul.py
index 5e5564a14..4c51518b2 100644
--- a/test/test_matmul.py
+++ b/test/test_matmul.py
@@ -217,21 +217,14 @@ def test_plain_matrix_mul_new_ui(ctx_factory):
             outer_tag="g.0", inner_tag="l.1", no_slabs=True)
     knl = lp.split_dimension(knl, "j", 16,
             outer_tag="g.1", inner_tag="l.0", no_slabs=True)
-    for insn in knl.instructions:
-        print insn
-    print
-    knl = lp.realize_reduction(knl, "k")
-    for insn in knl.instructions:
-        print insn
-    print
     knl = lp.split_dimension(knl, "k", 16, no_slabs=True)
 
     knl = lp.realize_cse(knl, "lhsmat", dtype, ["k_inner", "i_inner"])
     knl = lp.realize_cse(knl, "rhsmat", dtype, ["j_inner", "k_inner"])
 
-    for insn in knl.instructions:
-        print insn
-
+    #print
+    #for insn in knl.instructions:
+        #print insn
     #assert lp.get_problems(knl, {})[0] <= 2
 
     kernel_gen = lp.generate_loop_schedules(knl)
-- 
GitLab