# Provenance: commit cb15587b71e6a7805edc1bee7feddf0b58584497
# Author: Andreas Kloeckner <inform@tiker.net>, Sun, 12 Jul 2015
# Subject: Bring (mildly fixed) perf tweak for dim_{min,max} over from islpy
#
# The code below is the section that commit appends to loopy/isl_helpers.py
# (after obj_involves_variable). It uses the names `isl` and `dim_type`,
# which are not defined in this excerpt -- presumably they come from the
# islpy imports at the top of that module (not visible here).
#
# Companion change in loopy/kernel/tools.py from the same commit:
# SetOperationCacheManager.dim_min / dim_max no longer call
# isl.dim_min_with_elimination / isl.dim_max_with_elimination; each method
# now imports the same-named function from loopy.isl_helpers locally, and
# the module-level "import islpy as isl" is removed there.


# {{{ performance tweak for dim_{min,max}: project first

def _runs_in_integer_set(s, max_int=None):
    """Yield ``(start, length)`` for each maximal run of consecutive
    non-negative integers contained in *s*, in increasing order.

    :arg s: a container of integers supporting ``in`` (and :func:`max`
        when *max_int* is not given). If it is empty, nothing is yielded.
    :arg max_int: exclusive upper bound of the scanned range
        ``0 <= i < max_int``. Defaults to ``max(s) + 1`` so that the
        largest element of *s* is included in the scan.
    """
    if not s:
        return

    if max_int is None:
        # The bound is exclusive; without the "+ 1" the largest element
        # would never be visited (and a one-element set would yield nothing).
        max_int = max(s) + 1

    i = 0
    while i < max_int:
        if i in s:
            start = i

            # Advance to the first integer past the end of this run.
            i += 1
            while i < max_int and i in s:
                i += 1

            yield (start, i - start)

        else:
            i += 1


class TooManyInteractingDims(Exception):
    """Raised by :func:`_find_noninteracting_dims` once the set of
    interacting dimensions reaches its *stop_at* limit."""


def _find_aff_dims(aff, dim_types_and_gen_dim_types):
    """Return the set of ``(gen_dt, index)`` pairs naming the dimensions
    that occur in *aff* with a nonzero coefficient.

    Div expressions of *aff* are searched recursively and their hits are
    merged into the result.

    :arg aff: object providing ``dim``, ``get_coefficient_val`` and
        ``get_div`` (as used on the result of ``Constraint.get_aff()``).
    :arg dim_types_and_gen_dim_types: re-iterable sequence of
        ``(dt, gen_dt)`` pairs. Each ``dt`` is queried on *aff*; matches
        are reported under the corresponding ``gen_dt``.
    """
    result = set()

    for dt, gen_dt in dim_types_and_gen_dim_types:
        for i in range(aff.dim(dt)):
            if not aff.get_coefficient_val(dt, i).is_zero():
                result.add((gen_dt, i))

    for i in range(aff.dim(dim_type.div)):
        result.update(_find_aff_dims(
            aff.get_div(i),
            dim_types_and_gen_dim_types))

    return result


def _transitive_closure(graph_dict):
    """Unimplemented placeholder: ignores *graph_dict* and returns *None*.

    Nothing in this section calls it.
    """
    pass


def _find_noninteracting_dims(obj, dt, idx, other_dt, stop_at=6):
    """Return the set of indices of the *other_dt* dimensions of *obj*
    that are not linked, through any chain of constraints, to dimension
    ``(dt, idx)``.

    Every constraint that involves more than one dimension links all the
    dimensions it involves. Starting from ``(dt, idx)``, linked dimensions
    are accumulated until a full pass over the constraints adds nothing.

    :arg obj: an ``isl.BasicSet`` or ``isl.Set``.
    :arg stop_at: give up once this many dimensions (counting
        ``(dt, idx)`` itself) have been found to interact. The limit is
        only checked while constraints are being processed.
    :raises TypeError: if *obj* is neither of the supported types.
    :raises TooManyInteractingDims: if the *stop_at* limit is reached.
    """
    if isinstance(obj, isl.BasicSet):
        basics = [obj]
    elif isinstance(obj, isl.Set):
        basics = obj.get_basic_sets()
    else:
        raise TypeError("unsupported arg type '%s'" % type(obj))

    # Dimensions found under dim_type.in_ on a constraint's aff are
    # recorded as dim_type.set; parameters are recorded as parameters.
    aff_dim_types = [
        (dim_type.param, dim_type.param),
        (dim_type.in_, dim_type.set)]

    connections = []
    for bs in basics:
        for c in bs.get_constraints():
            conn = _find_aff_dims(c.get_aff(), aff_dim_types)
            # A constraint on a single dimension links nothing.
            if len(conn) > 1:
                connections.append(conn)

    interacting = set([(dt, idx)])

    # Compute the connected component near (dt, idx) by fixed point iteration
    while True:
        changed_something = False

        for conn in connections:
            prev_len = len(interacting)

            if interacting & conn:
                interacting.update(conn)

            if len(interacting) != prev_len:
                changed_something = True

            if len(interacting) >= stop_at:
                raise TooManyInteractingDims()

        if not changed_something:
            break

    return set(range(obj.dim(other_dt))) - set(
        i_idx for i_dt, i_idx in interacting
        if i_dt == other_dt)


def _eliminate_noninteracting(obj, dt, idx, other_dt):
    """Return *obj* with every *other_dt* dimension that does not interact
    with ``(dt, idx)`` eliminated.

    If :func:`_find_noninteracting_dims` raises
    :exc:`TooManyInteractingDims`, *obj* is returned unchanged.
    """
    try:
        nonint = _find_noninteracting_dims(obj, dt, idx, other_dt)

    except TooManyInteractingDims:
        return obj

    # Eliminate contiguous index ranges with one call each.
    for first, n in _runs_in_integer_set(nonint):
        obj = obj.eliminate(other_dt, first, n)

    return obj


def dim_min_with_elimination(obj, idx):
    """Return ``dim_min(idx)`` of *obj*, evaluated after eliminating the
    parameter dimensions that do not interact with output dimension *idx*.
    """
    obj_elim = _eliminate_noninteracting(obj, dim_type.out, idx, dim_type.param)
    return obj_elim.dim_min(idx)


def dim_max_with_elimination(obj, idx):
    """Return ``dim_max(idx)`` of *obj*, evaluated after eliminating the
    parameter dimensions that do not interact with output dimension *idx*.
    """
    obj_elim = _eliminate_noninteracting(obj, dim_type.out, idx, dim_type.param)
    return obj_elim.dim_max(idx)

# }}}


# vim: foldmethod=marker