From 618cc90313f7b6807cb131331d85159c980fd38b Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 8 Nov 2011 19:51:38 -0500
Subject: [PATCH] Pick not just axis 0, but all auto axes by lowest available
 stride.

---
 MEMO                |   4 +-
 loopy/preprocess.py | 118 ++++++++++++++++++--------------------------
 loopy/symbolic.py   |   2 +-
 3 files changed, 52 insertions(+), 72 deletions(-)

diff --git a/MEMO b/MEMO
index eb389e7a9..f09d76315 100644
--- a/MEMO
+++ b/MEMO
@@ -41,8 +41,6 @@ To-do
 
 - dim_max caching
 
-- Pick not just axis 0, but all axes by lowest available stride
-
 - Fix all tests
 
 - Deal with equality constraints.
@@ -87,6 +85,8 @@ Future ideas
 Dealt with
 ^^^^^^^^^^
 
+- Pick not just axis 0, but all axes by lowest available stride
+
 - Scheduler tries too many boostability-related options
 
 - Automatically generate testing code vs. sequential.
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 30baf5e07..fd851d3a6 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -332,9 +332,9 @@ def limit_boostability(kernel):
 
 # }}}
 
-# {{{ guess good iname for local axis 0
+# {{{ rank inames by stride
 
-def get_axis_0_ranking(kernel, insn):
+def get_auto_axis_iname_ranking_by_stride(kernel, insn):
     from loopy.kernel import ImageArg, ScalarArg
 
     approximate_arg_values = dict(
@@ -371,10 +371,10 @@ def get_axis_0_ranking(kernel, insn):
 
     # }}}
 
-    # {{{ figure out axis 0 candidates
+    # {{{ figure out automatic-axis inames
 
     from loopy.kernel import AutoLocalIndexTagBase
-    axis0_candidates = set(
+    auto_axis_inames = set(
             iname
             for iname in kernel.insn_inames(insn)
             if isinstance(kernel.iname_to_tag.get(iname),
@@ -384,15 +384,11 @@ def get_axis_0_ranking(kernel, insn):
 
     # {{{ figure out which iname should get mapped to local axis 0
 
-    # maps inames to vote counts
-    vote_count_for_l0 = {}
+    # maps inames to "aggregate stride"
+    aggregate_strides = {}
 
     from loopy.symbolic import CoefficientCollector
 
-    from pytools import argmin2
-
-    saw_relevant_access = False
-
     for aae in global_ary_acc_exprs:
         index_expr = aae.index
         if not isinstance(index_expr, tuple):
@@ -405,44 +401,29 @@ def get_axis_0_ranking(kernel, insn):
         if ary_strides is None and len(index_expr) == 1:
             ary_strides = (1,)
 
-        # {{{ construct iname_to_stride
+        # {{{ construct iname_to_stride_expr
 
-        iname_to_stride = {}
+        iname_to_stride_expr = {}
         for iexpr_i, stride in zip(index_expr, ary_strides):
             coeffs = CoefficientCollector()(iexpr_i)
             for var_name, coeff in coeffs.iteritems():
-                if var_name != 1:
+                if var_name in auto_axis_inames: # excludes '1', i.e.  the constant
                     new_stride = coeff*stride
-                    old_stride = iname_to_stride.get(var_name, None)
+                    old_stride = iname_to_stride_expr.get(var_name, None)
                     if old_stride is None or new_stride < old_stride:
-                        iname_to_stride[var_name] = new_stride
+                        iname_to_stride_expr[var_name] = new_stride
 
         # }}}
 
-        if set(iname_to_stride.keys()) & axis0_candidates:
-            saw_relevant_access = True
-
-        if iname_to_stride:
-            from pymbolic import evaluate
-            least_stride_iname, least_stride = argmin2((
-                    (iname,
-                        evaluate(iname_to_stride[iname], approximate_arg_values))
-                    for iname in iname_to_stride),
-                    return_value=True)
-
-            if least_stride == 1:
-                vote_strength = 1
-            else:
-                vote_strength = 0.5
-
-            vote_count_for_l0[least_stride_iname] = (
-                    vote_count_for_l0.get(least_stride_iname, 0)
-                    + vote_strength)
+        from pymbolic import evaluate
+        for iname, stride_expr in iname_to_stride_expr.iteritems():
+            stride = evaluate(stride_expr, approximate_arg_values)
+            aggregate_strides[iname] = aggregate_strides.get(iname, 0) + stride
 
-    if saw_relevant_access:
-        return sorted((iname for iname in kernel.insn_inames(insn)),
-                key=lambda iname: vote_count_for_l0.get(iname, 0),
-                reverse=True)
+    if aggregate_strides:
+        import sys
+        return  sorted((iname for iname in kernel.insn_inames(insn)),
+                key=lambda iname: aggregate_strides.get(iname, sys.maxint))
     else:
         return None
 
@@ -452,7 +433,7 @@ def get_axis_0_ranking(kernel, insn):
 
 # {{{ assign automatic axes
 
-def assign_automatic_axes(kernel, phase="axis0", local_size=None):
+def assign_automatic_axes(kernel, axis=0, local_size=None):
     from loopy.kernel import (AutoLocalIndexTagBase, LocalIndexTag,
             UnrollTag)
 
@@ -466,7 +447,7 @@ def assign_automatic_axes(kernel, phase="axis0", local_size=None):
 
     # {{{ axis assignment helper function
 
-    def assign_axis(iname, axis=None):
+    def assign_axis(recursion_axis, iname, axis=None):
         """Assign iname to local axis *axis* and start over by calling
         the surrounding function assign_automatic_axes.
 
@@ -516,23 +497,22 @@ def assign_automatic_axes(kernel, phase="axis0", local_size=None):
                         split_dimension(kernel, iname, inner_length=local_size[axis],
                             outer_tag=UnrollTag(), inner_tag=new_tag,
                             do_tagged_check=False),
-                        phase=phase, local_size=local_size)
+                        axis=recursion_axis, local_size=local_size)
+
+        if not isinstance(kernel.iname_to_tag.get(iname), AutoLocalIndexTagBase):
+            raise RuntimeError("trying to reassign '%s'" % iname)
 
         new_iname_to_tag = kernel.iname_to_tag.copy()
         new_iname_to_tag[iname] = new_tag
         return assign_automatic_axes(kernel.copy(iname_to_tag=new_iname_to_tag),
-                phase=phase, local_size=local_size)
+                axis=recursion_axis, local_size=local_size)
 
     # }}}
 
     # {{{ main assignment loop
 
-    # assignment proceeds in two phases:
-
-    # - "axis0": Only axis 0 is assigned on instructions that carry out
-    #   global array access based on l.auto axes
-    #
-    # - "rest": All other l.auto axes are assigned haphazardly.
+    # assignment proceeds in one phase per axis, each time assigning the
+    # smallest-stride available iname to the current axis
 
     for insn in kernel.instructions:
         auto_axis_inames = [
@@ -551,37 +531,37 @@ def assign_automatic_axes(kernel, phase="axis0", local_size=None):
             if isinstance(tag, LocalIndexTag):
                 assigned_local_axes.add(tag.axis)
 
-        if 0 < len(local_size) and 0 not in assigned_local_axes:
-            axis0_ranking = get_axis_0_ranking(kernel, insn)
-            if axis0_ranking is not None:
-                for axis0_iname in axis0_ranking:
-                    axis0_iname_tag = kernel.iname_to_tag.get(axis0_iname)
-                    if isinstance(axis0_iname_tag, AutoLocalIndexTagBase):
-                        return assign_axis(axis0_iname, 0)
+        if axis < len(local_size):
+            # "valid" pass: try to assign a given axis
 
-        if phase == "axis0":
-            continue
+            if axis not in assigned_local_axes:
+                iname_ranking = get_auto_axis_iname_ranking_by_stride(kernel, insn)
+                if iname_ranking is not None:
+                    for iname in iname_ranking:
+                        prev_tag = kernel.iname_to_tag.get(iname)
+                        if isinstance(prev_tag, AutoLocalIndexTagBase):
+                            return assign_axis(axis, iname, axis)
+
+        else:
+            # "invalid" pass: There are still unassigned axes after the
+            #  numbered "valid" passes--assign the remainder by length.
 
-        # assign longest auto axis inames first
-        auto_axis_inames.sort(key=kernel.get_constant_iname_length, reverse=True)
+            # assign longest auto axis inames first
+            auto_axis_inames.sort(key=kernel.get_constant_iname_length, reverse=True)
 
-        if auto_axis_inames:
-            return assign_axis(auto_axis_inames.pop())
+            if auto_axis_inames:
+                return assign_axis(axis, auto_axis_inames.pop())
 
     # }}}
 
     # We've seen all instructions and not punted to recursion/restart because
     # of a new axis assignment.
 
-    if phase == "axis0":
-        # If we were only assigining axis 0, then assign all the remaining
-        # axes next.
-        return assign_automatic_axes(kernel, phase="rest",
-                local_size=local_size)
-    else:
-        # If were already assigning all axes and got here, we're now done.
-        # All automatic axes are assigned.
+    if axis >= len(local_size):
         return kernel
+    else:
+        return assign_automatic_axes(kernel, axis=axis+1,
+                local_size=local_size)
 
 # }}}
 
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 427c9be5c..56cd896eb 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -287,7 +287,7 @@ class ArrayAccessFinder(CombineMapper):
 
 class LoopyCCodeMapper(CCodeMapper):
     def __init__(self, kernel, cse_name_list=[], var_subst_map={},
-            with_annotation=False):
+            with_annotation=True):
         def constant_mapper(c):
             if isinstance(c, float):
                 # FIXME: type-variable
-- 
GitLab