From 44c1beaa8b2d8e47bd3810c65410810ac0c52291 Mon Sep 17 00:00:00 2001
From: Tim Warburton <timwar@caam.rice.edu>
Date: Tue, 25 Oct 2011 21:59:42 -0500
Subject: [PATCH] Guess which iname should be l.0 by ranking, not by pointing
 at one.

This lets the higher-level routine choose its favorite, rather
than leave it out of options once the top choice doesn't work.
---
 loopy/schedule.py | 47 ++++++++++++++++++++++++-----------------------
 1 file changed, 24 insertions(+), 23 deletions(-)

diff --git a/loopy/schedule.py b/loopy/schedule.py
index 1c4a27652..fa224040d 100644
--- a/loopy/schedule.py
+++ b/loopy/schedule.py
@@ -367,7 +367,7 @@ def add_idempotence_and_automatic_dependencies(kernel):
 
 # {{{ guess good iname for local axis 0
 
-def guess_good_iname_for_axis_0(kernel, insn):
+def get_axis_0_ranking(kernel, insn):
     from loopy.kernel import ImageArg, ScalarArg
 
     approximate_arg_values = dict(
@@ -411,7 +411,7 @@ def guess_good_iname_for_axis_0(kernel, insn):
 
     from loopy.symbolic import CoefficientCollector
 
-    from pytools import argmin2, argmax2
+    from pytools import argmin2
 
     for aae in global_ary_acc_exprs:
         index_expr = aae.index
@@ -435,23 +435,26 @@ def guess_good_iname_for_axis_0(kernel, insn):
                     if old_stride is None or new_stride < old_stride:
                         iname_to_stride[var_name] = new_stride
 
-        from pymbolic import evaluate
-        least_stride_iname, least_stride = argmin2((
-                (iname,
-                    evaluate(iname_to_stride[iname], approximate_arg_values))
-                for iname in iname_to_stride),
-                return_value=True)
+        if iname_to_stride:
+            from pymbolic import evaluate
+            least_stride_iname, least_stride = argmin2((
+                    (iname,
+                        evaluate(iname_to_stride[iname], approximate_arg_values))
+                    for iname in iname_to_stride),
+                    return_value=True)
 
-        if least_stride == 1:
-            vote_strength = 1
-        else:
-            vote_strength = 0.5
+            if least_stride == 1:
+                vote_strength = 1
+            else:
+                vote_strength = 0.5
 
-        vote_count_for_l0[least_stride_iname] = (
-                vote_count_for_l0.get(least_stride_iname, 0)
-                + vote_strength)
+            vote_count_for_l0[least_stride_iname] = (
+                    vote_count_for_l0.get(least_stride_iname, 0)
+                    + vote_strength)
 
-    return argmax2(vote_count_for_l0.iteritems())
+    return sorted((iname for iname in insn.all_inames()),
+            key=lambda iname: vote_count_for_l0.get(iname, 0),
+            reverse=True)
 
     # }}}
 
@@ -531,13 +534,11 @@ def assign_automatic_axes(kernel, only_axis_0=True):
             if isinstance(tag, LocalIndexTag):
                 assigned_local_axes.add(tag.axis)
 
-        axis0_iname = guess_good_iname_for_axis_0(kernel, insn)
-
-        axis0_iname_tag = kernel.iname_to_tag.get(axis0_iname)
-        ax0_tag = LocalIndexTag(0)
-        if (isinstance(axis0_iname_tag, AutoLocalIndexTagBase)
-                and 0 not in assigned_local_axes):
-            return assign_axis(axis0_iname, 0)
+        if 0 not in assigned_local_axes:
+            for axis0_iname in get_axis_0_ranking(kernel, insn):
+                axis0_iname_tag = kernel.iname_to_tag.get(axis0_iname)
+                if isinstance(axis0_iname_tag, AutoLocalIndexTagBase):
+                    return assign_axis(axis0_iname, 0)
 
         if only_axis_0:
             continue
-- 
GitLab