Skip to content
Snippets Groups Projects
Commit de1a9c4d authored by Andreas Klöckner
Browse files

Allow split_arg_axis to work without splitting inames

parent 1cf7bf97
No related branches found
No related tags found
No related merge requests found
......@@ -41,7 +41,7 @@ class ArgAxisSplitHelper(ExpandingIdentityMapper):
return ExpandingIdentityMapper.map_subscript(self, expr, expn_state)
def split_arg_axis(kernel, args_and_axes, count):
def split_arg_axis(kernel, args_and_axes, count, auto_split_inames=True):
"""
:arg args_and_axes: a list of tuples *(arg, axis_nr)* indicating
that the index in *axis_nr* should be split. The tuples may
......@@ -51,6 +51,10 @@ def split_arg_axis(kernel, args_and_axes, count):
If *args_and_axes* is a :class:`tuple`, it is automatically
wrapped in a list, to make single splits easier.
:arg count: The group size to use in the split.
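:arg auto_split_inames: Whether to automatically split inames
encountered in the specified indices. If false, no inames are split;
instead, the index expression on the split axis is rewritten as
``index % count`` (inner) and ``index // count`` (outer).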
Note that, if *auto_split_inames* is true, splits on the corresponding
inames are carried out implicitly. In that case, the inames may *not* be
split beforehand. (There's no *really* good reason for this--this routine
is just not smart enough to deal with this.)
......@@ -163,29 +167,39 @@ def split_arg_axis(kernel, args_and_axes, count):
idx = list(idx)
axis_idx = idx[axis_nr]
from pymbolic.primitives import Variable
if not isinstance(axis_idx, Variable):
raise RuntimeError("found access '%s' in which axis %d is not a "
"single variable--cannot split (Have you tried to do the split "
"yourself, manually, beforehand? If so, you shouldn't.)"
% (expr, axis_nr))
split_iname = expr.index[axis_nr].name
assert split_iname in kernel.all_inames()
if auto_split_inames:
from pymbolic.primitives import Variable
if not isinstance(axis_idx, Variable):
raise RuntimeError("found access '%s' in which axis %d is not a "
"single variable--cannot split "
"(Have you tried to do the split yourself, manually, "
"beforehand? If so, you shouldn't.)"
% (expr, axis_nr))
split_iname = expr.index[axis_nr].name
assert split_iname in kernel.all_inames()
try:
outer_iname, inner_iname = split_vars[split_iname]
except KeyError:
outer_iname = var_name_gen(split_iname+"_outer")
inner_iname = var_name_gen(split_iname+"_inner")
split_vars[split_iname] = outer_iname, inner_iname
try:
outer_iname, inner_iname = split_vars[split_iname]
except KeyError:
outer_iname = var_name_gen(split_iname+"_outer")
inner_iname = var_name_gen(split_iname+"_inner")
split_vars[split_iname] = outer_iname, inner_iname
inner_index = Variable(inner_iname)
outer_index = Variable(outer_iname)
else:
inner_index = axis_idx % count
outer_index = axis_idx // count
idx[axis_nr] = Variable(inner_iname)
idx[axis_nr] = inner_index
if order == "F":
idx.insert(axis+1, Variable(outer_iname))
idx.insert(axis+1, outer_index)
elif order == "C":
idx.insert(axis, Variable(outer_iname))
idx.insert(axis, outer_index)
else:
raise RuntimeError("order '%s' not understood" % order)
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment