From 44c1beaa8b2d8e47bd3810c65410810ac0c52291 Mon Sep 17 00:00:00 2001 From: Tim Warburton <timwar@caam.rice.edu> Date: Tue, 25 Oct 2011 21:59:42 -0500 Subject: [PATCH] Guess which iname should be l.0 by ranking, not by pointing at one. This lets the higher-level routine choose its favorite, rather than leave it out of options once the top choice doesn't work. --- loopy/schedule.py | 47 ++++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/loopy/schedule.py b/loopy/schedule.py index 1c4a27652..fa224040d 100644 --- a/loopy/schedule.py +++ b/loopy/schedule.py @@ -367,7 +367,7 @@ def add_idempotence_and_automatic_dependencies(kernel): # {{{ guess good iname for local axis 0 -def guess_good_iname_for_axis_0(kernel, insn): +def get_axis_0_ranking(kernel, insn): from loopy.kernel import ImageArg, ScalarArg approximate_arg_values = dict( @@ -411,7 +411,7 @@ def guess_good_iname_for_axis_0(kernel, insn): from loopy.symbolic import CoefficientCollector - from pytools import argmin2, argmax2 + from pytools import argmin2 for aae in global_ary_acc_exprs: index_expr = aae.index @@ -435,23 +435,26 @@ def guess_good_iname_for_axis_0(kernel, insn): if old_stride is None or new_stride < old_stride: iname_to_stride[var_name] = new_stride - from pymbolic import evaluate - least_stride_iname, least_stride = argmin2(( - (iname, - evaluate(iname_to_stride[iname], approximate_arg_values)) - for iname in iname_to_stride), - return_value=True) + if iname_to_stride: + from pymbolic import evaluate + least_stride_iname, least_stride = argmin2(( + (iname, + evaluate(iname_to_stride[iname], approximate_arg_values)) + for iname in iname_to_stride), + return_value=True) - if least_stride == 1: - vote_strength = 1 - else: - vote_strength = 0.5 + if least_stride == 1: + vote_strength = 1 + else: + vote_strength = 0.5 - vote_count_for_l0[least_stride_iname] = ( - vote_count_for_l0.get(least_stride_iname, 0) - + vote_strength) + vote_count_for_l0[least_stride_iname] = ( + vote_count_for_l0.get(least_stride_iname, 0) + + vote_strength) - return argmax2(vote_count_for_l0.iteritems()) + return sorted((iname for iname in insn.all_inames()), + key=lambda iname: vote_count_for_l0.get(iname, 0), + reverse=True) # }}} @@ -531,13 +534,11 @@ def assign_automatic_axes(kernel, only_axis_0=True): if isinstance(tag, LocalIndexTag): assigned_local_axes.add(tag.axis) - axis0_iname = guess_good_iname_for_axis_0(kernel, insn) - - axis0_iname_tag = kernel.iname_to_tag.get(axis0_iname) - ax0_tag = LocalIndexTag(0) - if (isinstance(axis0_iname_tag, AutoLocalIndexTagBase) - and 0 not in assigned_local_axes): - return assign_axis(axis0_iname, 0) + if 0 not in assigned_local_axes: + for axis0_iname in get_axis_0_ranking(kernel, insn): + axis0_iname_tag = kernel.iname_to_tag.get(axis0_iname) + if isinstance(axis0_iname_tag, AutoLocalIndexTagBase): + return assign_axis(axis0_iname, 0) if only_axis_0: continue -- GitLab