diff --git a/loopy/transform/padding.py b/loopy/transform/padding.py index 6cff7e022ba3b94e2612793f4d85aaaca28b6cef..915742fc1003352c2952f27d87dbfd31fb26c9d7 100644 --- a/loopy/transform/padding.py +++ b/loopy/transform/padding.py @@ -252,23 +252,7 @@ split_arg_axis = MovedFunctionDeprecationWrapper(split_array_dim) # {{{ split_array_axis -def split_array_axis(kernel, array_name, axis_nr, count, order="C"): - """ - :arg array: may name a temporary variable or an argument. - - :arg axis_nr: the (zero-based) index of the axis that should be split. - - :arg count: The group size to use in the split. - - :arg order: The way the split array axis should be linearized. - May be "C" or "F" to indicate C/Fortran (row/column)-major order. - - .. versionchanged:: 2016.2 - - There was a more complicated, dumber function called :func:`split_array_dim` - that had the role of this function in versions prior to 2016.2. - """ - +def _split_array_axis_inner(kernel, array_name, axis_nr, count, order="C"): if count == 1: return kernel @@ -384,6 +368,33 @@ def split_array_axis(kernel, array_name, axis_nr, count, order="C"): return kernel + +def split_array_axis(kernel, array_names, axis_nr, count, order="C"): + """ + :arg array: a list of names of temporary variables or arguments. May + also be a comma-separated string of these. + + :arg axis_nr: the (zero-based) index of the axis that should be split. + + :arg count: The group size to use in the split. + + :arg order: The way the split array axis should be linearized. + May be "C" or "F" to indicate C/Fortran (row/column)-major order. + + .. versionchanged:: 2016.2 + + There was a more complicated, dumber function called :func:`split_array_dim` + that had the role of this function in versions prior to 2016.2. + """ + + if isinstance(array_names, str): + array_names = [i.strip() for i in array_names.split(",") if i.strip()] + + for array_name in array_names: + kernel = _split_array_axis_inner(kernel, array_name, axis_nr, count, order) + + return kernel + # }}}