diff --git a/loopy/check.py b/loopy/check.py index 093ddfde967e4b44aaa2e68ef8ce6cf15331509f..f6bf3d88ba6dc655c6e9ea4df40beb9dedd1fff8 100644 --- a/loopy/check.py +++ b/loopy/check.py @@ -218,6 +218,7 @@ def check_for_data_dependent_parallel_bounds(kernel): "inames '%s'. This is not allowed (for now)." % (i, par, ", ".join(par_inames))) + class _AccessCheckMapper(WalkMapper): def __init__(self, kernel, domain, insn_id): self.kernel = kernel @@ -238,56 +239,27 @@ class _AccessCheckMapper(WalkMapper): shape = tv.shape if shape is not None: - index = expr.index + subscript = expr.index + + if not isinstance(subscript, tuple): + subscript = (subscript,) - if not isinstance(index, tuple): - index = (index,) + from loopy.symbolic import get_dependencies, get_access_range - from loopy.symbolic import get_dependencies, aff_from_expr available_vars = set(self.domain.get_var_dict()) - if (get_dependencies(index) <= available_vars + if (get_dependencies(subscript) <= available_vars and get_dependencies(shape) <= available_vars): - dims = len(index) - - # we build access_map as a set because (idiocy!) Affs - # cannot live on maps. 
# {{{ get access range

def get_access_range(domain, subscript):
    """Return the range of the access relation described by *subscript*.

    One extra set dimension per entry of *subscript* is appended to
    *domain*, and each extra dimension is constrained to equal the
    corresponding subscript entry. The original domain dimensions are then
    moved to the input side of a map, and that map's range is returned.

    :arg domain: an :class:`isl.Set` or :class:`isl.BasicSet`; a basic set
        is promoted to a set before any dimensions are added.
    :arg subscript: a sized, iterable collection of index expressions,
        one per array axis. NOTE(review): each entry is handed to
        ``aff_from_expr``, so it is presumably required to be affine in
        the variables of *domain* -- confirm against that helper.
    :returns: the result of ``.range()`` on the assembled access map.
    """
    index_count = len(subscript)

    # Affs cannot be built on map spaces, so the access relation is first
    # assembled as a plain set with this dimension layout:
    #   [domain dims](domain_dim_count)[storage dims]
    if isinstance(domain, isl.BasicSet):
        staged = isl.Set.from_basic_set(domain)
    else:
        staged = domain

    domain_dim_count = staged.dim(dim_type.set)
    staged = staged.insert_dims(dim_type.set, domain_dim_count, index_count)

    for axis, index_expr in enumerate(subscript):
        # The aff below is (index_expr - storage_dim[axis]); requiring it
        # to be zero pins the storage dimension to the index expression.
        pinned = aff_from_expr(staged.get_space(), index_expr)
        pinned = pinned.set_coefficient(
                dim_type.in_, domain_dim_count + axis, -1)

        staged = staged.add_constraint(
                isl.Constraint.equality_from_aff(pinned))

    # Lift the staged set into a map whose range is that set, then move the
    # leading domain dimensions over to the input side so that only the
    # storage dimensions remain on the output side.
    lifted = isl.Map.universe(staged.get_space()).intersect_range(staged)
    access_relation = lifted.move_dims(
            dim_type.in_, 0,
            dim_type.out, 0, domain_dim_count)

    return access_relation.range()

# }}}

# vim: foldmethod=marker