diff --git a/islpy/__init__.py b/islpy/__init__.py index cd6ef6df5a7bf44bd3d93b3d5b450c9db8ecc8f6..45af7505ae407054881777af397e3a3224f4c827 100644 --- a/islpy/__init__.py +++ b/islpy/__init__.py @@ -1173,7 +1173,7 @@ def _align_dim_type(tgt_dt, obj, tgt, obj_bigger_ok, obj_names, tgt_names): return obj -def align_spaces(obj, tgt, obj_bigger_ok=False, across_dim_types=False): +def align_spaces(obj, tgt, obj_bigger_ok=False, across_dim_types=None): """ Try to make the space in which *obj* lives the same as that of *tgt* by adding/matching named dimensions. @@ -1182,6 +1182,11 @@ def align_spaces(obj, tgt, obj_bigger_ok=False, across_dim_types=False): has more dimensions than *tgt*. """ + if across_dim_types is not None: + warn("across_dim_types is deprecated and should no longer be used. " + "It never had any effect anyway.", + DeprecationWarning, stacklevel=2) + have_any_param_domains = ( isinstance(obj, (Set, BasicSet)) and isinstance(tgt, (Set, BasicSet)) @@ -1198,28 +1203,19 @@ def align_spaces(obj, tgt, obj_bigger_ok=False, across_dim_types=False): else: dim_types = _CHECK_DIM_TYPES - if across_dim_types: - obj_names = [ - obj.get_dim_name(dt, i) - for dt in dim_types - for i in range(obj.dim(dt)) - ] - tgt_names = [ - tgt.get_dim_name(dt, i) - for dt in dim_types - for i in range(tgt.dim(dt)) - ] - - for dt in dim_types: - obj = _align_dim_type(dt, obj, tgt, obj_bigger_ok, obj_names, tgt_names) - else: - obj_names = [obj.get_dim_name(dt, i) - for dt in dim_types for i in range(obj.dim(dt))] - tgt_names = [tgt.get_dim_name(dt, i) - for dt in dim_types for i in range(tgt.dim(dt))] - - for dt in dim_types: - obj = _align_dim_type(dt, obj, tgt, obj_bigger_ok, obj_names, tgt_names) + obj_names = [ + obj.get_dim_name(dt, i) + for dt in dim_types + for i in range(obj.dim(dt)) + ] + tgt_names = [ + tgt.get_dim_name(dt, i) + for dt in dim_types + for i in range(tgt.dim(dt)) + ] + + for dt in dim_types: + obj = _align_dim_type(dt, obj, tgt, obj_bigger_ok, obj_names, tgt_names) return obj diff --git a/test/test_isl.py b/test/test_isl.py index 46855df84c812e036d4832eb7235100ec6216f74..39a7f0f8b6356b0ca97077dbb61a5ae0f26f1f12 100644 --- a/test/test_isl.py +++ b/test/test_isl.py @@ -308,6 +308,14 @@ def test_lexmin(): """).lexmin()) +def test_align_spaces(): + m1 = isl.BasicMap("[m,n] -> {[i,j,k]->[l,o]:}") + m2 = isl.BasicMap("[m,n] -> {[j,k,l,i]->[o]:}") + + result = isl.align_spaces(m1, m2) + assert result.get_var_dict() == m2.get_var_dict() + + if __name__ == "__main__": import sys if len(sys.argv) > 1: