diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py index 1ad83c0a718b60194cb5789ced7d0a9c7674de75..b82965914a2fd6eabb460e3e68f31131862c6330 100644 --- a/loopy/kernel/array.py +++ b/loopy/kernel/array.py @@ -26,6 +26,7 @@ THE SOFTWARE. import re +import six from six.moves import range, zip from six import iteritems @@ -289,9 +290,32 @@ def _parse_array_dim_tag(tag, default_target_axis, nesting_levels): nesting_level, pad_to=pad_to, target_axis=target_axis)) -def parse_array_dim_tags(dim_tags, n_axes=None, use_increasing_target_axes=False): +def parse_array_dim_tags(dim_tags, n_axes=None, use_increasing_target_axes=False, + dim_names=None): if isinstance(dim_tags, str): dim_tags = dim_tags.split(",") + if isinstance(dim_tags, dict): + dim_tags_dict = dim_tags + + if dim_names is None: + raise LoopyError("dim_tags may only be given as a dictionary if " + "dim_names is available") + + assert n_axes == len(dim_names) + + dim_tags = [None]*n_axes + for dim_name, val in six.iteritems(dim_tags_dict): + try: + dim_idx = dim_names.index(dim_name) + except ValueError: + raise LoopyError("'%s' does not name an array axis" % dim_name) + + dim_tags[dim_idx] = val + + for idim, dim_tag in enumerate(dim_tags): + if dim_tag is None: + raise LoopyError("array axis tag for axis %d (1-based) was not " + "set by passed dictionary" % (idim + 1)) default_target_axis = 0 @@ -673,7 +697,8 @@ class ArrayBase(Record): if dim_tags is not None: dim_tags = parse_array_dim_tags(dim_tags, n_axes=(len(shape) if shape_known else None), - use_increasing_target_axes=self.max_target_axes > 1) + use_increasing_target_axes=self.max_target_axes > 1, + dim_names=dim_names) # {{{ determine number of user axes @@ -707,7 +732,8 @@ class ArrayBase(Record): if dim_tags is None and num_user_axes is not None and order is not None: dim_tags = parse_array_dim_tags(num_user_axes*[order], n_axes=num_user_axes, - use_increasing_target_axes=self.max_target_axes > 1) + use_increasing_target_axes=self.max_target_axes > 1, + dim_names=dim_names) order = None # }}} diff --git a/loopy/transform/data.py b/loopy/transform/data.py index 7b1deb7951392e2e0c46360f8fd979ebf5aedb37..02499ded2c912f4f0bbbdcfbfc4a3d3d161414b3 100644 --- a/loopy/transform/data.py +++ b/loopy/transform/data.py @@ -307,7 +307,8 @@ def tag_data_axes(knl, ary_names, dim_tags): from loopy.kernel.array import parse_array_dim_tags new_dim_tags = parse_array_dim_tags(dim_tags, n_axes=ary.num_user_axes(), - use_increasing_target_axes=ary.max_target_axes > 1) + use_increasing_target_axes=ary.max_target_axes > 1, + dim_names=ary.dim_names) ary = ary.copy(dim_tags=tuple(new_dim_tags))