diff --git a/doc/reference.rst b/doc/reference.rst index 47805c2b16d97b33ad0eb565e96e8a8c8da1e557..975ef80a92816a15647b1d29265c9362883cec1c 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -206,7 +206,7 @@ Influencing data access .. autofunction:: change_arg_to_image -.. autofunction:: tag_data_axis +.. autofunction:: tag_data_axes Padding ^^^^^^^ diff --git a/loopy/__init__.py b/loopy/__init__.py index 273e8573e3f7399004480845f10fe8de778bec01..0d868cb47d5b405b268949281d7789162a5326b1 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -1037,9 +1037,9 @@ def change_arg_to_image(knl, name): # }}} -# {{{ tag data axis +# {{{ tag data axes -def tag_data_axis(knl, ary_name, axis, tag): +def tag_data_axes(knl, ary_name, dim_tags): if ary_name in knl.temporary_variables: ary = knl.temporary_variables[ary_name] elif ary_name in knl.arg_dict: @@ -1047,9 +1047,9 @@ def tag_data_axis(knl, ary_name, axis, tag): else: raise NameError("array '%s' was not found" % ary_name) - new_dim_tags = list(ary.dim_tags) - from loopy.kernel.array import parse_array_dim_tag - new_dim_tags[axis] = parse_array_dim_tag(tag) + from loopy.kernel.array import parse_array_dim_tags + new_dim_tags = parse_array_dim_tags(dim_tags, + use_increasing_target_axes=ary.max_target_axes > 1) ary = ary.copy(dim_tags=tuple(new_dim_tags)) diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py index f90100f008d4e6bcaca0e5d17a8c646919a57675..21db8c0d4527149ab3dc72c33f5d31adb6745234 100644 --- a/loopy/kernel/array.py +++ b/loopy/kernel/array.py @@ -167,8 +167,8 @@ def parse_array_dim_tags(dim_tags, use_increasing_target_axes=False): default_target_axis = 0 result = [] - for dt in dim_tags: - result.append(parse_array_dim_tag(dt, default_target_axis)) + for dim_tag in dim_tags: + result.append(parse_array_dim_tag(dim_tag, default_target_axis)) if use_increasing_target_axes: default_target_axis += 1 @@ -194,27 +194,27 @@ def convert_computed_to_fixed_dim_tags(name, num_user_axes, num_target_axes, computed_stride_dim_tags = [[] for i in range(num_target_axes)] fixed_stride_dim_tags = [[] for i in range(num_target_axes)] - for i, dt in enumerate(dim_tags): - if isinstance(dt, VectorArrayDimTag): + for i, dim_tag in enumerate(dim_tags): + if isinstance(dim_tag, VectorArrayDimTag): if vector_dim is not None: raise ValueError("arg '%s' may only have one vector-tagged " "argument dimension" % name) vector_dim = i - elif isinstance(dt, FixedStrideArrayDimTag): - fixed_stride_dim_tags[dt.target_axis].append(i) + elif isinstance(dim_tag, FixedStrideArrayDimTag): + fixed_stride_dim_tags[dim_tag.target_axis].append(i) - elif isinstance(dt, ComputedStrideArrayDimTag): - if dt.order in "cC": - computed_stride_dim_tags[dt.target_axis].insert(0, i) - elif dt.order in "fF": - computed_stride_dim_tags[dt.target_axis].append(i) + elif isinstance(dim_tag, ComputedStrideArrayDimTag): + if dim_tag.order in "cC": + computed_stride_dim_tags[dim_tag.target_axis].insert(0, i) + elif dim_tag.order in "fF": + computed_stride_dim_tags[dim_tag.target_axis].append(i) else: raise ValueError("invalid value '%s' for " - "ComputedStrideArrayDimTag.order" % dt.order) + "ComputedStrideArrayDimTag.order" % dim_tag.order) - elif isinstance(dt, SeparateArrayArrayDimTag): + elif isinstance(dim_tag, SeparateArrayArrayDimTag): pass else: @@ -422,6 +422,10 @@ class ArrayBase(Record): # }}} + if dim_tags is not None: + dim_tags = parse_array_dim_tags(dim_tags, + use_increasing_target_axes=self.max_target_axes > 1) + # {{{ determine number of user axes num_user_axes = None @@ -435,7 +439,8 @@ class ArrayBase(Record): else: if new_num_user_axes != num_user_axes: raise ValueError("contradictory values for number of dimensions " - "from shape, strides, or dim_tags") + "of array '%s' from shape, strides, or dim_tags" + % name) del new_num_user_axes @@ -449,21 +454,19 @@ class ArrayBase(Record): order = "C" if dim_tags is None and num_user_axes is not None and order is not None: - dim_tags = num_user_axes*[order] + dim_tags = parse_array_dim_tags(num_user_axes*[order], + use_increasing_target_axes=self.max_target_axes > 1) order = None # }}} if dim_tags is not None: - dim_tags = parse_array_dim_tags(dim_tags, - use_increasing_target_axes=self.max_target_axes > 1) - # {{{ find number of target axes target_axes = set() - for dt in dim_tags: - if isinstance(dt, _StrideArrayDimTagBase): - target_axes.add(dt.target_axis) + for dim_tag in dim_tags: + if isinstance(dim_tag, _StrideArrayDimTagBase): + target_axes.add(dim_tag.target_axis) if target_axes != set(xrange(len(target_axes))): raise ValueError("target axes for variable '%s' are non-" @@ -541,9 +544,9 @@ class ArrayBase(Record): def num_target_axes(self): target_axes = set() - for dt in self.dim_tags: - if isinstance(dt, _StrideArrayDimTagBase): - target_axes.add(dt.target_axis) + for dim_tag in self.dim_tags: + if isinstance(dim_tag, _StrideArrayDimTagBase): + target_axes.add(dim_tag.target_axis) return len(target_axes) @@ -569,7 +572,8 @@ class ArrayBase(Record): kwargs["shape"] = tuple(mapper(s) for s in self.shape) if self.dim_tags is not None: - kwargs["dim_tags"] = [dt.map_expr(mapper) for dt in self.dim_tags] + kwargs["dim_tags"] = [dim_tag.map_expr(mapper) + for dim_tag in self.dim_tags] # offset is not an expression, do not map.