From 9058893b96628b2bc9c10988c082730d497dd4f3 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 16 Oct 2011 00:10:50 -0400
Subject: [PATCH] Be smarter about (split and such) automatic local dimensions.

---
 loopy/codegen/__init__.py |  13 ++--
 loopy/codegen/loop.py     |  11 ++-
 loopy/kernel.py           |  22 +++---
 loopy/schedule.py         | 147 +++++++++++++++++++++++---------------
 4 files changed, 115 insertions(+), 78 deletions(-)

diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index fee274b46..1718e8b8f 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -112,10 +112,15 @@ class CodeGenerationState(object):
                 self.implemented_domain.intersect(set),
                 self.c_code_mapper)
 
-    def fix(self, iname, aff):
-        dt, pos = aff.get_space().get_var_dict()[iname]
-        iname_plus_lb_aff = aff.add_coefficient(
-                dt, pos, -1)
+    def fix(self, iname, aff, space):
+        dt, pos = space.get_var_dict()[iname]
+        assert dt == isl.dim_type.set
+
+        zero = isl.Aff.zero_on_domain(space)
+
+        from islpy import align_spaces
+        iname_plus_lb_aff = align_spaces(aff, zero).add_coefficient(
+                isl.dim_type.in_, pos, -1)
 
         from loopy.symbolic import pw_aff_to_expr
         cns = isl.Constraint.equality_from_aff(iname_plus_lb_aff)
diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py
index 14d6c5862..f8df9ddf4 100644
--- a/loopy/codegen/loop.py
+++ b/loopy/codegen/loop.py
@@ -79,13 +79,11 @@ def get_slab_decomposition(kernel, iname, sched_index, codegen_state):
 
 # }}}
 
-# {{{ unrolled/ILP loops
+# {{{ unrolled loops
 
 def generate_unroll_loop(kernel, sched_index, codegen_state):
     from loopy.isl_helpers import block_shift_constraint
 
-    from cgen import (POD, Line)
-
     ccm = codegen_state.c_code_mapper
     space = kernel.space
     iname = kernel.schedule[sched_index].iname
@@ -98,7 +96,8 @@ def generate_unroll_loop(kernel, sched_index, codegen_state):
     from loopy.isl_helpers import static_max_of_pw_aff
     from loopy.symbolic import pw_aff_to_expr
 
-    length = int(pw_aff_to_expr(static_max_of_pw_aff(bounds.length)))
+    length = int(pw_aff_to_expr(
+        static_max_of_pw_aff(bounds.size, constants_only=True)))
     lower_bound_pw_aff_pieces = bounds.lower_bound_pw_aff.coalesce().get_pieces()
 
     if len(lower_bound_pw_aff_pieces) > 1:
@@ -116,11 +115,11 @@ def generate_unroll_loop(kernel, sched_index, codegen_state):
 
     from loopy.kernel import UnrollTag
     if isinstance(tag, UnrollTag):
-        result = [POD(np.int32, iname), Line()]
+        result = []
 
         for i in range(length):
             idx_aff = lower_bound_aff + i
-            new_codegen_state = codegen_state.fix(iname, idx_aff)
+            new_codegen_state = codegen_state.fix(iname, idx_aff, kernel.space)
             result.append(
                     build_loop_nest(kernel, sched_index+1, new_codegen_state))
 
diff --git a/loopy/kernel.py b/loopy/kernel.py
index dcf14172d..699b159ff 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -67,11 +67,7 @@ class AutoLocalIndexTagBase(LocalIndexTagBase):
 
 class AutoFitLocalIndexTag(AutoLocalIndexTagBase):
     def __str__(self):
-        return "l.fit"
-
-class AutoPickLocalIndexTag(AutoLocalIndexTagBase):
-    def __str__(self):
-        return "l.pick"
+        return "l.auto"
 
 class IlpTag(ParallelTag):
     def __str__(self):
@@ -99,9 +95,7 @@ def parse_tag(tag):
         return GroupIndexTag(int(tag[2:]))
     elif tag.startswith("l."):
         axis = tag[2:]
-        if axis == "pick":
-            return AutoPickLocalIndexTag()
-        elif axis == "fit":
+        if axis == "auto":
             return AutoFitLocalIndexTag()
         else:
             return LocalIndexTag(int(axis))
@@ -605,6 +599,14 @@ class LoopKernel(Record):
                 upper_bound_pw_aff=upper_bound_pw_aff,
                 size=size)
 
+    @memoize_method
+    def get_constant_iname_length(self, iname):
+        from loopy.isl_helpers import static_max_of_pw_aff
+        from loopy.symbolic import aff_to_expr
+        return int(aff_to_expr(static_max_of_pw_aff(
+                self.get_iname_bounds(iname).size,
+                constants_only=True)))
+
     @memoize_method
     def get_grid_sizes(self, ignore_auto=False):
         all_inames_by_insns = set()
@@ -678,8 +680,8 @@ class LoopKernel(Record):
         return (to_dim_tuple(global_sizes, "global"),
                 to_dim_tuple(local_sizes, "local"))
 
-    def get_grid_sizes_as_exprs(self):
-        grid_size, group_size = self.get_grid_sizes()
+    def get_grid_sizes_as_exprs(self, ignore_auto=False):
+        grid_size, group_size = self.get_grid_sizes(ignore_auto=ignore_auto)
 
         def tup_to_exprs(tup):
             from loopy.symbolic import pw_aff_to_expr
diff --git a/loopy/schedule.py b/loopy/schedule.py
index e0c36afb5..de2cb928e 100644
--- a/loopy/schedule.py
+++ b/loopy/schedule.py
@@ -130,7 +130,7 @@ def check_for_unused_hw_axes(kernel):
 
 
 
-def check_double_use_of_hw_axes(kernel):
+def check_for_double_use_of_hw_axes(kernel):
     from loopy.kernel import UniqueTag
 
     for insn in kernel.instructions:
@@ -373,79 +373,109 @@ def guess_good_iname_for_axis_0(kernel, insn):
 
 
 
-def find_inadmissible_tag_keys(kernel, iname, iname_to_tag=None):
-    if iname_to_tag is None:
-        iname_to_tag = kernel.iname_to_tag
+def assign_automatic_axes(kernel, only_axis_0=True):
+    from loopy.kernel import (AutoLocalIndexTagBase, LocalIndexTag,
+            UnrollTag)
 
-    result = set()
+    global_size, local_size = kernel.get_grid_sizes_as_exprs(
+            ignore_auto=True)
 
-    from loopy.kernel import UniqueTag
+    def assign_axis(iname, axis=None):
+        print "assign", iname
+        desired_length = kernel.get_constant_iname_length(iname)
 
-    for insn in kernel.instructions:
-        if iname in insn.all_inames():
-            for insn_iname in insn.all_inames():
-                if insn_iname == iname:
-                    continue
+        if axis is None:
+            # {{{ find a suitable axis
 
-                tag = iname_to_tag.get(insn_iname)
-                if isinstance(tag, UniqueTag):
-                    result.add(tag.key)
+            # find already assigned local axes (to avoid them)
+            shorter_possible_axes = []
+            test_axis = 0
+            while True:
+                if test_axis >= len(local_size):
+                    break
+                if test_axis in assigned_local_axes:
+                    test_axis += 1
+                    continue
 
-    return result
+                if local_size[test_axis] < desired_length:
+                    shorter_possible_axes.append(test_axis)
+                    test_axis += 1
+                    continue
+                else:
+                    axis = test_axis
+                    break
 
+            # longest first
+            shorter_possible_axes.sort(key=lambda ax: local_size[ax])
 
+            if axis is None and shorter_possible_axes:
+                axis = shorter_possible_axes[0]
 
+            # }}}
 
-def assign_automatic_axes(kernel):
-    from loopy.kernel import (
-            TAG_AUTO_LOCAL_IDX, LocalIndexTag)
+        if axis is None:
+            new_tag = None
+        else:
+            new_tag = LocalIndexTag(axis)
+            print iname, desired_length, local_size[axis]
+            if desired_length > local_size[axis]:
+                from loopy import split_dimension
+                return assign_automatic_axes(
+                        split_dimension(kernel, iname, inner_length=local_size[axis],
+                            outer_tag=UnrollTag(), inner_tag=new_tag, no_slabs=True),
+                        only_axis_0=only_axis_0)
+
+        new_iname_to_tag = kernel.iname_to_tag.copy()
+        new_iname_to_tag[iname] = new_tag
+        return assign_automatic_axes(kernel.copy(iname_to_tag=new_iname_to_tag),
+                only_axis_0=only_axis_0)
 
-    new_iname_to_tag = kernel.iname_to_tag
+    for insn in kernel.instructions:
+        auto_axis_inames = [
+                iname
+                for iname in insn.all_inames()
+                if isinstance(kernel.iname_to_tag.get(iname),
+                    AutoLocalIndexTagBase)]
 
-    # first assign each insn's axis 0, then the rest
-    for only_axis_0 in [True, False]:
+        if not auto_axis_inames:
+            continue
 
-        for insn in kernel.instructions:
-            auto_axis_inames = [
-                    iname
-                    for iname in insn.all_inames()
-                    if isinstance(new_iname_to_tag.get(iname), TAG_AUTO_LOCAL_IDX)]
+        assigned_local_axes = set()
 
-            if not auto_axis_inames:
-                continue
-
-            local_assigned_axes = set()
+        for iname in insn.all_inames():
+            tag = kernel.iname_to_tag.get(iname)
+            if isinstance(tag, LocalIndexTag):
+                assigned_local_axes.add(tag.axis)
 
-            for iname in insn.all_inames():
-                tag = new_iname_to_tag.get(iname)
-                if isinstance(tag, LocalIndexTag):
-                    local_assigned_axes.add(tag.axis)
+        axis0_iname = guess_good_iname_for_axis_0(kernel, insn)
 
-            if 0 not in local_assigned_axes:
-                axis0_iname = guess_good_iname_for_axis_0(kernel, insn)
+        axis0_iname_tag = kernel.iname_to_tag.get(axis0_iname)
+        ax0_tag = LocalIndexTag(0)
+        if (isinstance(axis0_iname_tag, AutoLocalIndexTagBase)
+                and 0 not in assigned_local_axes):
+            return assign_axis(axis0_iname, 0)
 
-                axis0_iname_tag = new_iname_to_tag.get(axis0_iname)
-                ax0_tag = LocalIndexTag(0)
-                if (isinstance(axis0_iname_tag, TAG_AUTO_LOCAL_IDX)
-                        and ax0_tag.key not in find_inadmissible_tag_keys(
-                            kernel, axis0_iname, new_iname_to_tag)):
-                    new_iname_to_tag[axis0_iname] = ax0_tag
-                    local_assigned_axes.add(0)
-                    auto_axis_inames.remove(axis0_iname)
+        if only_axis_0:
+            continue
 
-            if only_axis_0:
-                continue
+        # assign longest auto axis inames first
+        auto_axis_inames.sort(key=kernel.get_constant_iname_length, reverse=True)
 
-            next_axis = 0
-            while auto_axis_inames:
-                iname = auto_axis_inames.pop()
-                while next_axis in local_assigned_axes:
-                    next_axis += 1
+        next_axis = 0
+        if auto_axis_inames:
+            return assign_axis(auto_axis_inames.pop())
 
-                new_iname_to_tag[iname] = LocalIndexTag(next_axis)
-                local_assigned_axes.add(next_axis)
+    # We've seen all instructions and not punted to recursion/restart because
+    # of a new axis assignment.
 
-    return kernel.copy(iname_to_tag=new_iname_to_tag)
+    if only_axis_0:
+        # If we were only assigining axis 0, then assign all the remaining 
+        # axes next.
+        return assign_automatic_axes(kernel, only_axis_0=False)
+    else:
+        # If were already assigning all axes and got here, we're now done.
+        # All automatic axes are assigned.
+        return kernel
 
 
 
@@ -699,8 +729,6 @@ def insert_barriers(kernel, schedule, level=0):
 
 def generate_loop_schedules(kernel):
     kernel = realize_reduction(kernel)
-    check_double_use_of_hw_axes(kernel)
-    kernel = adjust_local_temp_var_storage(kernel)
 
     # {{{ check that all CSEs have been realized
 
@@ -714,8 +742,12 @@ def generate_loop_schedules(kernel):
 
     # }}}
 
-    kernel = add_automatic_dependencies(kernel)
     kernel = assign_automatic_axes(kernel)
+    kernel = add_automatic_dependencies(kernel)
+    kernel = adjust_local_temp_var_storage(kernel)
+
+    print kernel
+    check_for_double_use_of_hw_axes(kernel)
     check_for_unused_hw_axes(kernel)
 
     for gen_sched in generate_loop_schedules_internal(kernel):
@@ -727,7 +759,6 @@ def generate_loop_schedules(kernel):
 
 
 
-
 # {{{ schedule utilities
 
 def find_active_inames_at(kernel, sched_index):
-- 
GitLab