Skip to content
Snippets Groups Projects
Commit 44c1beaa authored by Tim Warburton's avatar Tim Warburton
Browse files

Guess which iname should be l.0 by ranking, not by pointing at one.

This lets the higher-level routine choose its favorite, rather
than leaving it with no options once the top choice doesn't work.
parent 16290f2b
No related branches found
No related tags found
No related merge requests found
......@@ -367,7 +367,7 @@ def add_idempotence_and_automatic_dependencies(kernel):
# {{{ guess good iname for local axis 0
def guess_good_iname_for_axis_0(kernel, insn):
def get_axis_0_ranking(kernel, insn):
from loopy.kernel import ImageArg, ScalarArg
approximate_arg_values = dict(
......@@ -411,7 +411,7 @@ def guess_good_iname_for_axis_0(kernel, insn):
from loopy.symbolic import CoefficientCollector
from pytools import argmin2, argmax2
from pytools import argmin2
for aae in global_ary_acc_exprs:
index_expr = aae.index
......@@ -435,23 +435,26 @@ def guess_good_iname_for_axis_0(kernel, insn):
if old_stride is None or new_stride < old_stride:
iname_to_stride[var_name] = new_stride
from pymbolic import evaluate
least_stride_iname, least_stride = argmin2((
(iname,
evaluate(iname_to_stride[iname], approximate_arg_values))
for iname in iname_to_stride),
return_value=True)
if iname_to_stride:
from pymbolic import evaluate
least_stride_iname, least_stride = argmin2((
(iname,
evaluate(iname_to_stride[iname], approximate_arg_values))
for iname in iname_to_stride),
return_value=True)
if least_stride == 1:
vote_strength = 1
else:
vote_strength = 0.5
if least_stride == 1:
vote_strength = 1
else:
vote_strength = 0.5
vote_count_for_l0[least_stride_iname] = (
vote_count_for_l0.get(least_stride_iname, 0)
+ vote_strength)
vote_count_for_l0[least_stride_iname] = (
vote_count_for_l0.get(least_stride_iname, 0)
+ vote_strength)
return argmax2(vote_count_for_l0.iteritems())
return sorted((iname for iname in insn.all_inames()),
key=lambda iname: vote_count_for_l0.get(iname, 0),
reverse=True)
# }}}
......@@ -531,13 +534,11 @@ def assign_automatic_axes(kernel, only_axis_0=True):
if isinstance(tag, LocalIndexTag):
assigned_local_axes.add(tag.axis)
axis0_iname = guess_good_iname_for_axis_0(kernel, insn)
axis0_iname_tag = kernel.iname_to_tag.get(axis0_iname)
ax0_tag = LocalIndexTag(0)
if (isinstance(axis0_iname_tag, AutoLocalIndexTagBase)
and 0 not in assigned_local_axes):
return assign_axis(axis0_iname, 0)
if 0 not in assigned_local_axes:
for axis0_iname in get_axis_0_ranking(kernel, insn):
axis0_iname_tag = kernel.iname_to_tag.get(axis0_iname)
if isinstance(axis0_iname_tag, AutoLocalIndexTagBase):
return assign_axis(axis0_iname, 0)
if only_axis_0:
continue
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment