diff --git a/islpy/__init__.py b/islpy/__init__.py index 47f80370eb31d64bd748a0084fc495fb3d7891fb..1c64e1b20b9b60d555ad5dd816846d22cfbaa67c 100644 --- a/islpy/__init__.py +++ b/islpy/__init__.py @@ -1211,38 +1211,81 @@ def align_two(obj1, obj2, across_dim_types=False): # {{{ performance tweak for dim_{min,max}: project first -def _find_noninteracting_dims(obj, dt, idx, other_dt): - candidate_dims = set(range(obj.dim(other_dt))) +class TooManyInteractingDims(Exception): + pass + +def _find_aff_dims(aff, dim_types_and_gen_dim_types): + result = [] + + for dt, gen_dt in dim_types_and_gen_dim_types: + for i in range(aff.dim(dt)): + if not aff.get_coefficient_val(dt, i).is_zero(): + result.append((gen_dt, i)) + + result = set(result) + + for i in range(aff.dim(dim_type.div)): + result.update(_find_aff_dims(aff.get_div(i))) + + return result + + +def _transitive_closure(graph_dict): + pass + + +def _find_noninteracting_dims(obj, dt, idx, other_dt, stop_at=6): if isinstance(obj, BasicSet): basics = [obj] elif isinstance(obj, Set): basics = obj.get_basic_sets() - elif isinstance(obj, BasicMap): - basics = [obj] - elif isinstance(obj, Map): - basics = obj.get_basic_maps() else: raise TypeError("unsupported arg type '%s'" % type(obj)) + connections = [] for bs in basics: for c in bs.get_constraints(): - if not c.involves_dims(dt, idx, 1): - continue + conn = _find_aff_dims( + c.get_aff(), + [(dim_type.param, dim_type.param), (dim_type.in_, dim_type.set)]) + if len(conn) > 1: + connections.append(conn) + + interacting = set([(dt, idx)]) + + while True: + changed_something = False + + # Compute the connected component near (dt, idx) by fixed point iteration - found_interacting = set() + for conn in connections: + prev_len = len(interacting) - for dim in candidate_dims: - if c.involves_dims(other_dt, dim, 1): - found_interacting.add(dim) + overlap = interacting & conn + if overlap: + interacting.update(conn) - candidate_dims -= found_interacting + if len(interacting) != prev_len: + changed_something = True - return candidate_dims + if len(interacting) >= stop_at: + raise TooManyInteractingDims() + + if not changed_something: + break + + return set(range(obj.dim(other_dt))) - set( + idx for dt, idx in interacting + if dt == other_dt) def _eliminate_noninteracting(obj, dt, idx, other_dt): - nonint = _find_noninteracting_dims(obj, dt, idx, other_dt) + try: + nonint = _find_noninteracting_dims(obj, dt, idx, other_dt) + + except TooManyInteractingDims: + return obj for first, n in _runs_in_integer_set(nonint): obj = obj.eliminate(other_dt, first, n) @@ -1251,13 +1294,13 @@ def _eliminate_noninteracting(obj, dt, idx, other_dt): def dim_min_with_elimination(obj, idx): - obj = _eliminate_noninteracting(obj, dim_type.out, idx, dim_type.param) - return obj.dim_min(idx) + obj_elim = _eliminate_noninteracting(obj, dim_type.out, idx, dim_type.param) + return obj_elim.dim_min(idx) def dim_max_with_elimination(obj, idx): - obj = _eliminate_noninteracting(obj, dim_type.out, idx, dim_type.param) - return obj.dim_max(idx) + obj_elim = _eliminate_noninteracting(obj, dim_type.out, idx, dim_type.param) + return obj_elim.dim_max(idx) # }}}