diff --git a/islpy/__init__.py b/islpy/__init__.py
index cd6ef6df5a7bf44bd3d93b3d5b450c9db8ecc8f6..45af7505ae407054881777af397e3a3224f4c827 100644
--- a/islpy/__init__.py
+++ b/islpy/__init__.py
@@ -1173,7 +1173,7 @@ def _align_dim_type(tgt_dt, obj, tgt, obj_bigger_ok, obj_names, tgt_names):
     return obj
 
 
-def align_spaces(obj, tgt, obj_bigger_ok=False, across_dim_types=False):
+def align_spaces(obj, tgt, obj_bigger_ok=False, across_dim_types=None):
     """
     Try to make the space in which *obj* lives the same as that of *tgt* by
     adding/matching named dimensions.
@@ -1182,6 +1182,11 @@ def align_spaces(obj, tgt, obj_bigger_ok=False, across_dim_types=False):
         has more dimensions than *tgt*.
     """
 
+    if across_dim_types is not None:
+        warn("across_dim_types is deprecated and should no longer be used. "
+                "It never had any effect anyway.",
+                DeprecationWarning, stacklevel=2)
+
     have_any_param_domains = (
             isinstance(obj, (Set, BasicSet))
             and isinstance(tgt, (Set, BasicSet))
@@ -1198,28 +1203,19 @@ def align_spaces(obj, tgt, obj_bigger_ok=False, across_dim_types=False):
     else:
         dim_types = _CHECK_DIM_TYPES
 
-    if across_dim_types:
-        obj_names = [
-                obj.get_dim_name(dt, i)
-                for dt in dim_types
-                for i in range(obj.dim(dt))
-                ]
-        tgt_names = [
-                tgt.get_dim_name(dt, i)
-                for dt in dim_types
-                for i in range(tgt.dim(dt))
-                ]
-
-        for dt in dim_types:
-            obj = _align_dim_type(dt, obj, tgt, obj_bigger_ok, obj_names, tgt_names)
-    else:
-        obj_names = [obj.get_dim_name(dt, i)
-                for dt in dim_types for i in range(obj.dim(dt))]
-        tgt_names = [tgt.get_dim_name(dt, i)
-                for dt in dim_types for i in range(tgt.dim(dt))]
-
-        for dt in dim_types:
-            obj = _align_dim_type(dt, obj, tgt, obj_bigger_ok, obj_names, tgt_names)
+    obj_names = [
+            obj.get_dim_name(dt, i)
+            for dt in dim_types
+            for i in range(obj.dim(dt))
+            ]
+    tgt_names = [
+            tgt.get_dim_name(dt, i)
+            for dt in dim_types
+            for i in range(tgt.dim(dt))
+            ]
+
+    for dt in dim_types:
+        obj = _align_dim_type(dt, obj, tgt, obj_bigger_ok, obj_names, tgt_names)
 
     return obj
 
diff --git a/test/test_isl.py b/test/test_isl.py
index 46855df84c812e036d4832eb7235100ec6216f74..39a7f0f8b6356b0ca97077dbb61a5ae0f26f1f12 100644
--- a/test/test_isl.py
+++ b/test/test_isl.py
@@ -308,6 +308,14 @@ def test_lexmin():
             """).lexmin())
 
 
+def test_align_spaces():
+    m1 = isl.BasicMap("[m,n] -> {[i,j,k]->[l,o]:}")
+    m2 = isl.BasicMap("[m,n] -> {[j,k,l,i]->[o]:}")
+
+    result = isl.align_spaces(m1, m2)
+    assert result.get_var_dict() == m2.get_var_dict()
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1: