diff --git a/loopy/transform/padding.py b/loopy/transform/padding.py index 021c5bfab31584a19cc7da87964f9f8cdc0e08ec..5557b00658b3f64c1656419dbfa370d50e76e40c 100644 --- a/loopy/transform/padding.py +++ b/loopy/transform/padding.py @@ -42,6 +42,8 @@ class ArrayAxisSplitHelper(RuleAwareIdentityMapper): return super(ArrayAxisSplitHelper, self).map_subscript(expr, expn_state) +# {{{ split_array_dim (deprecated since June 2016) + def split_array_dim(kernel, arrays_and_axes, count, auto_split_inames=True, split_kwargs=None): """ @@ -213,8 +215,9 @@ def split_array_dim(kernel, arrays_and_axes, count, auto_split_inames=True, outer_index = Variable(outer_iname) else: - inner_index = axis_idx % count - outer_index = axis_idx // count + from loopy.symbolic import simplify_using_affs + inner_index = simplify_using_affs(kernel, axis_idx % count) + outer_index = simplify_using_affs(kernel, axis_idx // count) idx[axis_nr] = inner_index @@ -244,6 +247,142 @@ def split_array_dim(kernel, arrays_and_axes, count, auto_split_inames=True, split_arg_axis = MovedFunctionDeprecationWrapper(split_array_dim) +# }}} + + +# {{{ split_array_axis + +def split_array_axis(kernel, array_name, axis_nr, count, order="C"): + """ + :arg array: may name a temporary variable or an argument. + + :arg axis_nr: the (zero-based) index of the axis that should be split. + + :arg count: The group size to use in the split. + + :arg order: The way the split array axis should be linearized. + May be "C" or "F" to indicate C/Fortran (row/column)-major order. + """ + + if count == 1: + return kernel + + # {{{ adjust arrays + + from loopy.kernel.tools import ArrayChanger + + achng = ArrayChanger(kernel, array_name) + ary = achng.get() + + from pytools import div_ceil + + # {{{ adjust shape + + new_shape = ary.shape + if new_shape is not None: + new_shape = list(new_shape) + axis_len = new_shape[axis_nr] + new_shape[axis_nr] = count + outer_len = div_ceil(axis_len, count) + + if order == "F": + new_shape.insert(axis_nr+1, outer_len) + elif order == "C": + new_shape.insert(axis_nr, outer_len) + else: + raise RuntimeError("order '%s' not understood" % order) + new_shape = tuple(new_shape) + + # }}} + + # {{{ adjust dim tags + + if ary.dim_tags is None: + raise RuntimeError("dim_tags of '%s' are not known" % array_name) + new_dim_tags = list(ary.dim_tags) + + old_dim_tag = ary.dim_tags[axis_nr] + + from loopy.kernel.array import FixedStrideArrayDimTag + if not isinstance(old_dim_tag, FixedStrideArrayDimTag): + raise RuntimeError("axis %d of '%s' is not tagged fixed-stride" + % (axis_nr, array_name)) + + old_stride = old_dim_tag.stride + outer_stride = count*old_stride + + if order == "F": + new_dim_tags.insert(axis_nr+1, FixedStrideArrayDimTag(outer_stride)) + elif order == "C": + new_dim_tags.insert(axis_nr, FixedStrideArrayDimTag(outer_stride)) + else: + raise RuntimeError("order '%s' not understood" % order) + + new_dim_tags = tuple(new_dim_tags) + + # }}} + + # {{{ adjust dim_names + + new_dim_names = ary.dim_names + if new_dim_names is not None: + new_dim_names = list(new_dim_names) + existing_name = new_dim_names[axis_nr] + new_dim_names[axis_nr] = existing_name + "_inner" + outer_name = existing_name + "_outer" + + if order == "F": + new_dim_names.insert(axis_nr+1, outer_name) + elif order == "C": + new_dim_names.insert(axis_nr, outer_name) + else: + raise RuntimeError("order '%s' not understood" % order) + new_dim_names = tuple(new_dim_names) + + # }}} + + kernel = achng.with_changed_array(ary.copy( + shape=new_shape, dim_tags=new_dim_tags, dim_names=new_dim_names)) + + # }}} + + var_name_gen = kernel.get_var_name_generator() + + def split_access_axis(expr): + idx = expr.index + if not isinstance(idx, tuple): + idx = (idx,) + idx = list(idx) + + axis_idx = idx[axis_nr] + + from loopy.symbolic import simplify_using_affs + inner_index = simplify_using_affs(kernel, axis_idx % count) + outer_index = simplify_using_affs(kernel, axis_idx // count) + + idx[axis_nr] = inner_index + + if order == "F": + idx.insert(axis_nr+1, outer_index) + elif order == "C": + idx.insert(axis_nr, outer_index) + else: + raise RuntimeError("order '%s' not understood" % order) + + return expr.aggregate.index(tuple(idx)) + + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, var_name_gen) + aash = ArrayAxisSplitHelper(rule_mapping_context, + set([array_name]), split_access_axis) + kernel = rule_mapping_context.finish_kernel(aash.map_kernel(kernel)) + + return kernel + +# }}} + + +# {{{ find_padding_multiple def find_padding_multiple(kernel, variable, axis, align_bytes, allowed_waste=0.1): arg = kernel.arg_dict[variable] @@ -278,6 +417,10 @@ def find_padding_multiple(kernel, variable, axis, align_bytes, allowed_waste=0.1 multiple += 1 +# }}} + + +# {{{ add_padding def add_padding(kernel, variable, axis, align_bytes): arg_to_idx = dict((arg.name, i) for i, arg in enumerate(kernel.args)) @@ -312,5 +455,7 @@ def add_padding(kernel, variable, axis, align_bytes): return kernel.copy(args=new_args) +# }}} + # vim: foldmethod=marker