diff --git a/loopy/schedule.py b/loopy/schedule.py index 1c4a276527717c2a232722d0929e561047f253ef..fa224040d19c48348d85c3e291dd381f1c8ed3ca 100644 --- a/loopy/schedule.py +++ b/loopy/schedule.py @@ -367,7 +367,7 @@ def add_idempotence_and_automatic_dependencies(kernel): # {{{ guess good iname for local axis 0 -def guess_good_iname_for_axis_0(kernel, insn): +def get_axis_0_ranking(kernel, insn): from loopy.kernel import ImageArg, ScalarArg approximate_arg_values = dict( @@ -411,7 +411,7 @@ def guess_good_iname_for_axis_0(kernel, insn): from loopy.symbolic import CoefficientCollector - from pytools import argmin2, argmax2 + from pytools import argmin2 for aae in global_ary_acc_exprs: index_expr = aae.index @@ -435,23 +435,26 @@ def guess_good_iname_for_axis_0(kernel, insn): if old_stride is None or new_stride < old_stride: iname_to_stride[var_name] = new_stride - from pymbolic import evaluate - least_stride_iname, least_stride = argmin2(( - (iname, - evaluate(iname_to_stride[iname], approximate_arg_values)) - for iname in iname_to_stride), - return_value=True) + if iname_to_stride: + from pymbolic import evaluate + least_stride_iname, least_stride = argmin2(( + (iname, + evaluate(iname_to_stride[iname], approximate_arg_values)) + for iname in iname_to_stride), + return_value=True) - if least_stride == 1: - vote_strength = 1 - else: - vote_strength = 0.5 + if least_stride == 1: + vote_strength = 1 + else: + vote_strength = 0.5 - vote_count_for_l0[least_stride_iname] = ( - vote_count_for_l0.get(least_stride_iname, 0) - + vote_strength) + vote_count_for_l0[least_stride_iname] = ( + vote_count_for_l0.get(least_stride_iname, 0) + + vote_strength) - return argmax2(vote_count_for_l0.iteritems()) + return sorted((iname for iname in insn.all_inames()), + key=lambda iname: vote_count_for_l0.get(iname, 0), + reverse=True) # }}} @@ -531,13 +534,11 @@ def assign_automatic_axes(kernel, only_axis_0=True): if isinstance(tag, LocalIndexTag): assigned_local_axes.add(tag.axis) - axis0_iname = guess_good_iname_for_axis_0(kernel, insn) - - axis0_iname_tag = kernel.iname_to_tag.get(axis0_iname) - ax0_tag = LocalIndexTag(0) - if (isinstance(axis0_iname_tag, AutoLocalIndexTagBase) - and 0 not in assigned_local_axes): - return assign_axis(axis0_iname, 0) + if 0 not in assigned_local_axes: + for axis0_iname in get_axis_0_ranking(kernel, insn): + axis0_iname_tag = kernel.iname_to_tag.get(axis0_iname) + if isinstance(axis0_iname_tag, AutoLocalIndexTagBase): + return assign_axis(axis0_iname, 0) if only_axis_0: continue