diff --git a/loopy/padding.py b/loopy/padding.py index 714a921bce64434682217235e8f6c75581f1dbd3..8f747d413e0182f2a863bbe13d745362b993d1ee 100644 --- a/loopy/padding.py +++ b/loopy/padding.py @@ -41,7 +41,7 @@ class ArgAxisSplitHelper(ExpandingIdentityMapper): return ExpandingIdentityMapper.map_subscript(self, expr, expn_state) -def split_arg_axis(kernel, args_and_axes, count): +def split_arg_axis(kernel, args_and_axes, count, auto_split_inames=True): """ :arg args_and_axes: a list of tuples *(arg, axis_nr)* indicating that the index in *axis_nr* should be split. The tuples may @@ -51,6 +51,10 @@ def split_arg_axis(kernel, args_and_axes, count): If *args_and_axes* is a :class:`tuple`, it is automatically wrapped in a list, to make single splits easier. + :arg count: The group size to use in the split. + :arg auto_split_inames: Whether to automatically split inames + encountered in the specified indices. + Note that splits on the corresponding inames are carried out implicitly. The inames may *not* be split beforehand. (There's no *really* good reason for this--this routine is just not smart enough to deal with this.) @@ -163,29 +167,39 @@ def split_arg_axis(kernel, args_and_axes, count): idx = list(idx) axis_idx = idx[axis_nr] - from pymbolic.primitives import Variable - if not isinstance(axis_idx, Variable): - raise RuntimeError("found access '%s' in which axis %d is not a " - "single variable--cannot split (Have you tried to do the split " - "yourself, manually, beforehand? If so, you shouldn't.)" - % (expr, axis_nr)) - split_iname = expr.index[axis_nr].name - assert split_iname in kernel.all_inames() + if auto_split_inames: + from pymbolic.primitives import Variable + if not isinstance(axis_idx, Variable): + raise RuntimeError("found access '%s' in which axis %d is not a " + "single variable--cannot split " + "(Have you tried to do the split yourself, manually, " + "beforehand? If so, you shouldn't.)" + % (expr, axis_nr)) + + split_iname = expr.index[axis_nr].name + assert split_iname in kernel.all_inames() + + try: + outer_iname, inner_iname = split_vars[split_iname] + except KeyError: + outer_iname = var_name_gen(split_iname+"_outer") + inner_iname = var_name_gen(split_iname+"_inner") + split_vars[split_iname] = outer_iname, inner_iname - try: - outer_iname, inner_iname = split_vars[split_iname] - except KeyError: - outer_iname = var_name_gen(split_iname+"_outer") - inner_iname = var_name_gen(split_iname+"_inner") - split_vars[split_iname] = outer_iname, inner_iname + inner_index = Variable(inner_iname) + outer_index = Variable(outer_iname) + + else: + inner_index = axis_idx % count + outer_index = axis_idx // count - idx[axis_nr] = Variable(inner_iname) + idx[axis_nr] = inner_index if order == "F": - idx.insert(axis+1, Variable(outer_iname)) + idx.insert(axis+1, outer_index) elif order == "C": - idx.insert(axis, Variable(outer_iname)) + idx.insert(axis, outer_index) else: raise RuntimeError("order '%s' not understood" % order)