diff --git a/islpy/__init__.py b/islpy/__init__.py index d5dd79ad6af5abd053f6be4cfd31cda0be376c9e..6f54ba1ca8f0a088165678c28cfa3d659c90bb7d 100644 --- a/islpy/__init__.py +++ b/islpy/__init__.py @@ -245,45 +245,6 @@ def _add_functionality(): # }}} - # {{{ rich comparisons - - def obj_eq(self, other): - assert self.get_ctx() == other.get_ctx(), ( - "Equality-comparing two objects from different ISL Contexts " - "will likely lead to entertaining (but never useful) results. " - "In particular, Spaces with matching names will no longer be " - "equal.") - - return self.is_equal(other) - - def obj_ne(self, other): - return not self.__eq__(other) - - for cls in ALL_CLASSES: - if hasattr(cls, "is_equal"): - cls.__eq__ = obj_eq - cls.__ne__ = obj_ne - - def obj_lt(self, other): - return self.is_strict_subset(other) - - def obj_le(self, other): - return self.is_subset(other) - - def obj_gt(self, other): - return other.is_strict_subset(self) - - def obj_ge(self, other): - return other.is_subset(self) - - for cls in [BasicSet, BasicMap, Set, Map]: - cls.__lt__ = obj_lt - cls.__le__ = obj_le - cls.__gt__ = obj_gt - cls.__ge__ = obj_ge - - # }}} - # {{{ Python set-like behavior def obj_or(self, other): @@ -976,6 +937,49 @@ def _add_functionality(): # }}} + # ORDERING DEPENDENCY: The availability of some of the 'is_equal' + # used by rich comparison below depends on the self upcasts created + # above. + + # {{{ rich comparisons + + def obj_eq(self, other): + assert self.get_ctx() == other.get_ctx(), ( + "Equality-comparing two objects from different ISL Contexts " + "will likely lead to entertaining (but never useful) results. " + "In particular, Spaces with matching names will no longer be " + "equal.") + + return self.is_equal(other) + + def obj_ne(self, other): + return not self.__eq__(other) + + for cls in ALL_CLASSES: + if hasattr(cls, "is_equal"): + cls.__eq__ = obj_eq + cls.__ne__ = obj_ne + + def obj_lt(self, other): + return self.is_strict_subset(other) + + def obj_le(self, other): + return self.is_subset(other) + + def obj_gt(self, other): + return other.is_strict_subset(self) + + def obj_ge(self, other): + return other.is_subset(self) + + for cls in [BasicSet, BasicMap, Set, Map]: + cls.__lt__ = obj_lt + cls.__le__ = obj_le + cls.__gt__ = obj_gt + cls.__ge__ = obj_ge + + # }}} + # {{{ project_out_except def obj_project_out_except(obj, names, types): @@ -1102,47 +1106,56 @@ def _set_dim_id(obj, dt, idx, id): return _back_to_basic(obj.set_dim_id(dt, idx, id), obj) -def _align_dim_type(tgt_dt, obj, tgt, obj_bigger_ok, obj_names, tgt_names): +def _align_dim_type(template_dt, obj, template, obj_bigger_ok, obj_names, + template_names): + + # {{{ deal with Aff, PwAff + # The technique below will not work for PwAff et al, because there is *only* # the 'param' dim_type, and we are not allowed to move dims around in there. # We'll make isl do the work, using align_params. - if tgt_dt == dim_type.param and isinstance(obj, (Aff, PwAff)): - if not isinstance(tgt, Space): - tgt_space = tgt.space + if template_dt == dim_type.param and isinstance(obj, (Aff, PwAff)): + if not isinstance(template, Space): + template_space = template.space else: - tgt_space = tgt - if (not obj_bigger_ok - or obj.space.dim(dim_type.param) == tgt_space.dim(dim_type.param)): - return obj.align_params(tgt_space) - - if None in tgt_names: - all_nones = [None] * len(tgt_names) - if tgt_names == all_nones and obj_names == all_nones: + template_space = template + + if not obj_bigger_ok: + if (obj.dim(template_dt) > template.dim(template_dt) + or not set(obj.get_var_dict()) <= set(template.get_var_dict())): + raise Error("obj has leftover dimensions after alignment") + return obj.align_params(template_space) + + # }}} + + if None in template_names: + all_nones = [None] * len(template_names) + if template_names == all_nones and obj_names == all_nones: # that's ok return obj - raise RuntimeError("tgt may not contain any unnamed dimensions") + raise Error("template may not contain any unnamed dimensions") obj_names = set(obj_names) - set([None]) - tgt_names = set(tgt_names) - set([None]) + template_names = set(template_names) - set([None]) - names_in_both = obj_names & tgt_names + names_in_both = obj_names & template_names tgt_idx = 0 - while tgt_idx < tgt.dim(tgt_dt): - tgt_id = tgt.get_dim_id(tgt_dt, tgt_idx) + while tgt_idx < template.dim(template_dt): + tgt_id = template.get_dim_id(template_dt, tgt_idx) tgt_name = tgt_id.name if tgt_name in names_in_both: - if (obj.dim(tgt_dt) > tgt_idx - and tgt_name == obj.get_dim_name(tgt_dt, tgt_idx)): + if (obj.dim(template_dt) > tgt_idx + and tgt_name == obj.get_dim_name(template_dt, tgt_idx)): pass else: src_dt, src_idx = obj.get_var_dict()[tgt_name] - if src_dt == tgt_dt: + if src_dt == template_dt: assert src_idx > tgt_idx # isl requires move_dims to be between different types. @@ -1153,33 +1166,34 @@ def _align_dim_type(tgt_dt, obj, tgt, obj_bigger_ok, obj_names, tgt_names): other_dt_dim = obj.dim(other_dt) obj = obj.move_dims(other_dt, other_dt_dim, src_dt, src_idx, 1) - obj = obj.move_dims(tgt_dt, tgt_idx, other_dt, other_dt_dim, 1) + obj = obj.move_dims( + template_dt, tgt_idx, other_dt, other_dt_dim, 1) else: - obj = obj.move_dims(tgt_dt, tgt_idx, src_dt, src_idx, 1) + obj = obj.move_dims(template_dt, tgt_idx, src_dt, src_idx, 1) # names are same, make Ids the same, too - obj = _set_dim_id(obj, tgt_dt, tgt_idx, tgt_id) + obj = _set_dim_id(obj, template_dt, tgt_idx, tgt_id) tgt_idx += 1 else: - obj = obj.insert_dims(tgt_dt, tgt_idx, 1) - obj = _set_dim_id(obj, tgt_dt, tgt_idx, tgt_id) + obj = obj.insert_dims(template_dt, tgt_idx, 1) + obj = _set_dim_id(obj, template_dt, tgt_idx, tgt_id) tgt_idx += 1 - if tgt_idx < obj.dim(tgt_dt) and not obj_bigger_ok: - raise ValueError("obj has leftover dimensions") + if tgt_idx < obj.dim(template_dt) and not obj_bigger_ok: + raise Error("obj has leftover dimensions after alignment") return obj -def align_spaces(obj, tgt, obj_bigger_ok=False, across_dim_types=None): +def align_spaces(obj, template, obj_bigger_ok=False, across_dim_types=None): """ - Try to make the space in which *obj* lives the same as that of *tgt* by + Try to make the space in which *obj* lives the same as that of *template* by adding/matching named dimensions. :param obj_bigger_ok: If *True*, no error is raised if the resulting *obj* - has more dimensions than *tgt*. + has more dimensions than *template*. """ if across_dim_types is not None: @@ -1190,15 +1204,15 @@ def align_spaces(obj, tgt, obj_bigger_ok=False, across_dim_types=None): have_any_param_domains = ( isinstance(obj, (Set, BasicSet)) - and isinstance(tgt, (Set, BasicSet)) - and (obj.is_params() or tgt.is_params())) + and isinstance(template, (Set, BasicSet)) + and (obj.is_params() or template.is_params())) if have_any_param_domains: if obj.is_params(): obj = type(obj).from_params(obj) - if tgt.is_params(): - tgt = type(tgt).from_params(tgt) + if template.is_params(): + template = type(template).from_params(template) - if isinstance(tgt, EXPR_CLASSES): + if isinstance(template, EXPR_CLASSES): dim_types = _CHECK_DIM_TYPES[:] dim_types.remove(dim_type.out) else: @@ -1209,14 +1223,15 @@ def align_spaces(obj, tgt, obj_bigger_ok=False, across_dim_types=None): for dt in dim_types for i in range(obj.dim(dt)) ] - tgt_names = [ - tgt.get_dim_name(dt, i) + template_names = [ + template.get_dim_name(dt, i) for dt in dim_types - for i in range(tgt.dim(dt)) + for i in range(template.dim(dt)) ] for dt in dim_types: - obj = _align_dim_type(dt, obj, tgt, obj_bigger_ok, obj_names, tgt_names) + obj = _align_dim_type( + dt, obj, template, obj_bigger_ok, obj_names, template_names) return obj diff --git a/test/test_isl.py b/test/test_isl.py index 9e18953878a8df2912404c45326678ab6ece074d..da461ceb48760d5f7fd0f422320435513dbd5a83 100644 --- a/test/test_isl.py +++ b/test/test_isl.py @@ -309,6 +309,18 @@ def test_align_spaces(): result = isl.align_spaces(m1, m2) assert result.get_var_dict() == m2.get_var_dict() + a1 = isl.Aff("[t0, t1, t2] -> { [(32)] }") + a2 = isl.Aff("[t1, t0] -> { [(0)] }") + + with pytest.raises(isl.Error): + a1_aligned = isl.align_spaces(a1, a2) + + a1_aligned = isl.align_spaces(a1, a2, obj_bigger_ok=True) + a2_aligned = isl.align_spaces(a2, a1) + + assert a1_aligned == isl.Aff("[t1, t0, t2] -> { [(32)] }") + assert a2_aligned == isl.Aff("[t1, t0, t2] -> { [(0)] }") + def test_pass_numpy_int(): np = pytest.importorskip("numpy") @@ -320,6 +332,22 @@ def test_pass_numpy_int(): print(c1) +def test_isl_align_two(): + a1 = isl.Aff("[t0, t1, t2] -> { [(32)] }") + a2 = isl.Aff("[t1, t0] -> { [(0)] }") + + a1_aligned, a2_aligned = isl.align_two(a1, a2) + assert a1_aligned == isl.Aff("[t1, t0, t2] -> { [(32)] }") + assert a2_aligned == isl.Aff("[t1, t0, t2] -> { [(0)] }") + + b1 = isl.BasicSet("[n0, n1, n2] -> { [i0, i1] : }") + b2 = isl.BasicSet("[n0, n2, n1, n3] -> { [i1, i0, i2] : }") + + b1_aligned, b2_aligned = isl.align_two(b1, b2) + assert b1_aligned == isl.BasicSet("[n0, n2, n1, n3] -> { [i1, i0, i2] : }") + assert b2_aligned == isl.BasicSet("[n0, n2, n1, n3] -> { [i1, i0, i2] : }") + + if __name__ == "__main__": import sys if len(sys.argv) > 1: