diff --git a/loopy/array_buffer_map.py b/loopy/array_buffer_map.py
index 0be16ce5b3d9b7af477f7ee849d0dbe8e6b90c7c..72fca8a4a17a183251c877e41b033539e2f00c85 100644
--- a/loopy/array_buffer_map.py
+++ b/loopy/array_buffer_map.py
@@ -28,7 +28,7 @@ from islpy import dim_type
 from loopy.symbolic import (get_dependencies, SubstitutionMapper)
 from pymbolic.mapper.substitutor import make_subst_func
 
-from pytools import Record
+from pytools import Record, memoize_method
 from pymbolic import var
 
 
@@ -64,7 +64,7 @@ def to_parameters_or_project_out(param_inames, set_inames, set):
 
 # {{{ construct storage->sweep map
 
-def build_per_access_storage_to_domain_map(accdesc, domain,
+def build_per_access_storage_to_domain_map(storage_axis_exprs, domain,
         storage_axis_names,
         prime_sweep_inames):
 
@@ -91,7 +91,7 @@ def build_per_access_storage_to_domain_map(accdesc, domain,
 
     from loopy.symbolic import aff_from_expr
 
-    for saxis, sa_expr in zip(storage_axis_names, accdesc.storage_axis_exprs):
+    for saxis, sa_expr in zip(storage_axis_names, storage_axis_exprs):
         cns = isl.Constraint.equality_from_aff(
                 aff_from_expr(set_space,
                     var(saxis+"'") - prime_sweep_inames(sa_expr)))
@@ -138,7 +138,7 @@ def build_global_storage_to_sweep_map(kernel, access_descriptors,
     # build footprint
     for accdesc in access_descriptors:
         stor2sweep = build_per_access_storage_to_domain_map(
-                accdesc, domain_dup_sweep,
+                accdesc.storage_axis_exprs, domain_dup_sweep,
                 storage_axis_names,
                 prime_sweep_inames)
 
@@ -336,6 +336,11 @@ class ArrayToBufferMap(object):
         return convexify(domain)
 
     def is_access_descriptor_in_footprint(self, accdesc):
+        return self._is_access_descriptor_in_footprint_inner(
+                tuple(accdesc.storage_axis_exprs))
+
+    @memoize_method
+    def _is_access_descriptor_in_footprint_inner(self, storage_axis_exprs):
         # Make all inames except the sweep parameters. (The footprint may depend on
         # those.) (I.e. only leave sweep inames as out parameters.)
 
@@ -347,7 +352,7 @@ class ArrayToBufferMap(object):
                 set(global_s2s_par_dom.get_var_names(dim_type.param))
                 & self.kernel.all_inames())
 
-        for arg in accdesc.storage_axis_exprs:
+        for arg in storage_axis_exprs:
             arg_inames.update(get_dependencies(arg))
         arg_inames = frozenset(arg_inames)
 
@@ -363,7 +368,8 @@ class ArrayToBufferMap(object):
                 usage_domain = usage_domain.set_dim_name(
                         dim_type.set, i, iname+"'")
 
-        stor2sweep = build_per_access_storage_to_domain_map(accdesc,
+        stor2sweep = build_per_access_storage_to_domain_map(
+                storage_axis_exprs,
                 usage_domain, self.storage_axis_names,
                 self.prime_sweep_inames)
 
diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py
index e8409a9ce8a8bd20f2e2df81c20ce56b167ab090..a3a2a8d88bd6888f06d47a2f25b4b73eaf692ccd 100644
--- a/loopy/isl_helpers.py
+++ b/loopy/isl_helpers.py
@@ -384,4 +384,152 @@ def obj_involves_variable(obj, var_name):
     return False
 
 
+
+# {{{ performance tweak for dim_{min,max}: project first
+
+def _runs_in_integer_set(s, max_int=None):
+    """Yield ``(start, length)`` for each maximal run of consecutive
+    integers from *s* that lies within ``range(max_int)``.
+
+    If *max_int* is not given, it defaults to ``max(s) + 1`` so that the
+    largest element of *s* is included in the final run.
+    """
+    if not s:
+        return
+
+    if max_int is None:
+        # max_int is an exclusive bound (see the loop conditions below).
+        max_int = max(s) + 1
+
+    i = 0
+    while i < max_int:
+        if i in s:
+            start = i
+
+            i += 1
+            while i < max_int and i in s:
+                i += 1
+
+            end = i
+
+            yield (start, end-start)
+
+        else:
+            i += 1
+
+
+class TooManyInteractingDims(Exception):
+    pass
+
+
+def _find_aff_dims(aff, dim_types_and_gen_dim_types):
+    result = []
+
+    for dt, gen_dt in dim_types_and_gen_dim_types:
+        for i in range(aff.dim(dt)):
+            if not aff.get_coefficient_val(dt, i).is_zero():
+                result.append((gen_dt, i))
+
+    result = set(result)
+
+    for i in range(aff.dim(dim_type.div)):
+        if not aff.get_coefficient_val(dim_type.div, i).is_zero():
+            result.update(_find_aff_dims(
+                aff.get_div(i),
+                dim_types_and_gen_dim_types))
+
+    return result
+
+
+def _transitive_closure(graph_dict):
+    pass
+
+
+def _find_noninteracting_dims(obj, dt, idx, other_dt, stop_at=6):
+    """Return the indices of the dims of type *other_dt* that are not
+    linked to dim ``(dt, idx)`` through any chain of shared constraints.
+
+    Raise :class:`TooManyInteractingDims` once *stop_at* or more dims are
+    found to interact.
+    """
+    if isinstance(obj, isl.BasicSet):
+        basics = [obj]
+    elif isinstance(obj, isl.Set):
+        basics = obj.get_basic_sets()
+    else:
+        raise TypeError("unsupported arg type '%s'" % type(obj))
+
+    connections = []
+    for bs in basics:
+        for c in bs.get_constraints():
+            conn = _find_aff_dims(
+                    c.get_aff(),
+                    [(dim_type.param, dim_type.param), (dim_type.in_, dim_type.set)])
+            if len(conn) > 1:
+                connections.append(conn)
+
+    interacting = set([(dt, idx)])
+
+    while True:
+        changed_something = False
+
+        # Compute the connected component near (dt, idx) by fixed point iteration
+
+        for conn in connections:
+            prev_len = len(interacting)
+
+            overlap = interacting & conn
+            if overlap:
+                interacting.update(conn)
+
+            if len(interacting) != prev_len:
+                changed_something = True
+
+            if len(interacting) >= stop_at:
+                raise TooManyInteractingDims()
+
+        if not changed_something:
+            break
+
+    return set(range(obj.dim(other_dt))) - set(
+            idx for dt, idx in interacting
+            if dt == other_dt)
+
+
+def _eliminate_noninteracting(obj, dt, idx, other_dt):
+    """Eliminate the dims of type *other_dt* that do not interact with dim
+    ``(dt, idx)``. If too many dims interact, return *obj* with only
+    ``compute_divs()`` applied.
+    """
+    obj = obj.compute_divs()
+    try:
+        nonint = _find_noninteracting_dims(obj, dt, idx, other_dt)
+
+    except TooManyInteractingDims:
+        return obj
+
+    for first, n in _runs_in_integer_set(nonint):
+        obj = obj.eliminate(other_dt, first, n)
+
+    return obj
+
+
+def dim_min_with_elimination(obj, idx):
+    """Return ``dim_min(idx)`` of *obj* after eliminating the parameter dims
+    that do not interact with output dim *idx*.
+    """
+    obj_elim = _eliminate_noninteracting(obj, dim_type.out, idx, dim_type.param)
+    return obj_elim.dim_min(idx)
+
+
+def dim_max_with_elimination(obj, idx):
+    """Return ``dim_max(idx)`` of *obj* after eliminating the parameter dims
+    that do not interact with output dim *idx*.
+    """
+    obj_elim = _eliminate_noninteracting(obj, dim_type.out, idx, dim_type.param)
+    return obj_elim.dim_max(idx)
+
+# }}}
+
+
 # vim: foldmethod=marker
diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py
index 2973cd9e69e46c875e15e6c8674150373be8be1e..b59c40731d91d1689d2cd9c00884069d35f7856a 100644
--- a/loopy/kernel/tools.py
+++ b/loopy/kernel/tools.py
@@ -251,15 +251,17 @@ class SetOperationCacheManager:
                 return result
 
         #print op, set.get_dim_name(dim_type.set, args[0])
-        result = op(*args)
+        result = op(set, *args)
         bucket.append((set, op_name, args, result))
         return result
 
     def dim_min(self, set, *args):
-        return self.op(set, "dim_min", set.dim_min, args)
+        from loopy.isl_helpers import dim_min_with_elimination
+        return self.op(set, "dim_min", dim_min_with_elimination, args)
 
     def dim_max(self, set, *args):
-        return self.op(set, "dim_max", set.dim_max, args)
+        from loopy.isl_helpers import dim_max_with_elimination
+        return self.op(set, "dim_max", dim_max_with_elimination, args)
 
     def base_index_and_length(self, set, iname, context=None):
         if not isinstance(iname, int):