From b355bedead426758c27d32de7ee110a0c3dd251e Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Wed, 14 Aug 2024 13:54:09 -0500
Subject: [PATCH] more unique functions

---
 pytools/__init__.py          | 61 ++++++++++++++++++++++++++++++++----
 pytools/test/test_pytools.py | 27 +++++++++++-----
 2 files changed, 75 insertions(+), 13 deletions(-)

diff --git a/pytools/__init__.py b/pytools/__init__.py
index 25740e2..ff45b90 100644
--- a/pytools/__init__.py
+++ b/pytools/__init__.py
@@ -41,7 +41,6 @@ from typing import (
     Generic,
     Hashable,
     Iterable,
-    Iterator,
     List,
     Mapping,
     Optional,
@@ -211,10 +210,16 @@ String utilities
 .. autofunction:: strtobool
 .. autofunction:: to_identifier
 
-Sequence utilities
-------------------
+Set-like functions for iterables
+--------------------------------
+
+These functions provide set-like operations on iterables. In contrast to
+Python's built-in set type, they maintain the internal order of elements.
 
 .. autofunction:: unique
+.. autofunction:: unique_difference
+.. autofunction:: unique_intersection
+.. autofunction:: unique_union
 
 Type Variables Used
 -------------------
@@ -2991,11 +2996,55 @@ def to_identifier(s: str) -> str:
 
 # {{{ unique
 
-def unique(seq: Iterable[T]) -> Iterator[T]:
-    """Yield unique elements in *seq*, removing all duplicates. The internal
+def unique(seq: Iterable[T]) -> Collection[T]:
+    """Return unique elements in *seq*, removing all duplicates. The internal
     order of the elements is preserved. See also
     :func:`itertools.groupby` (which removes consecutive duplicates)."""
-    return iter(dict.fromkeys(seq))
+    return dict.fromkeys(seq)
+
+
+def unique_difference(*args: Iterable[T]) -> Collection[T]:
+    r"""Return unique elements that are in the first iterable in *\*args* but not
+    in any of the others. The internal order of the elements is preserved."""
+    if not args:
+        return []
+
+    res = dict.fromkeys(args[0])
+    for seq in args[1:]:
+        for item in seq:
+            if item in res:
+                del res[item]
+
+    return res
+
+
+def unique_intersection(*args: Iterable[T]) -> Collection[T]:
+    r"""Return unique elements that are common to all iterables in *\*args*.
+    The internal order of the elements is preserved."""
+    if not args:
+        return []
+
+    res = dict.fromkeys(args[0])
+    for seq in args[1:]:
+        seq = set(seq)
+        res = {item: None for item in res if item in seq}
+
+    return res
+
+
+def unique_union(*args: Iterable[T]) -> Collection[T]:
+    r"""Return unique elements that are in any iterable in *\*args*.
+    The internal order of the elements is preserved."""
+    if not args:
+        return []
+
+    res: Dict[T, None] = {}
+    for seq in args:
+        for item in seq:
+            if item not in res:
+                res[item] = None
+
+    return res
 
 # }}}
 
diff --git a/pytools/test/test_pytools.py b/pytools/test/test_pytools.py
index c6f6fda..c262fc9 100644
--- a/pytools/test/test_pytools.py
+++ b/pytools/test/test_pytools.py
@@ -766,7 +766,7 @@ def test_typedump():
 
 
 def test_unique():
-    from pytools import unique
+    from pytools import unique, unique_difference, unique_intersection, unique_union
 
     assert list(unique([1, 2, 1])) == [1, 2]
     assert tuple(unique((1, 2, 1))) == (1, 2)
@@ -774,14 +774,27 @@ def test_unique():
     assert list(range(1000)) == list(unique(range(1000)))
     assert list(unique(list(range(1000)) + list(range(1000)))) == list(range(1000))
 
-    assert next(unique([1, 2, 1, 3])) == 1
-    assert next(unique([]), None) is None
-
     # Also test strings since their ordering would be thrown off by
     # set-based 'unique' implementations.
-    assert list(unique(["A", "B", "A"])) == ["A", "B"]
-    assert tuple(unique(("A", "B", "A"))) == ("A", "B")
-    assert next(unique(["A", "B", "A", "C"])) == "A"
+    assert list(unique(["a", "b", "a"])) == ["a", "b"]
+    assert tuple(unique(("a", "b", "a"))) == ("a", "b")
+
+    assert list(unique_difference(["a", "b", "c"], ["b", "c", "d"])) == ["a"]
+    assert list(unique_difference(["a", "b", "c"], ["a", "b", "c", "d"])) == []
+    assert list(unique_difference(["a", "b", "c"], ["a"], ["b"], ["c"])) == []
+
+    assert list(unique_intersection(["a", "b", "a"], ["b", "c", "a"])) == ["a", "b"]
+    assert list(unique_intersection(["a", "b", "a"], ["d", "c", "e"])) == []
+
+    assert list(unique_union(["a", "b", "a"], ["b", "c", "b"])) == ["a", "b", "c"]
+    assert list(unique_union(
+        ["a", "b", "a"], ["b", "c", "b"], ["c", "d", "c"])) == ["a", "b", "c", "d"]
+    assert list(unique(["a", "b", "a"])) == \
+        list(unique_union(["a", "b", "a"])) == ["a", "b"]
+
+    assert list(unique_intersection()) == []
+    assert list(unique_difference()) == []
+    assert list(unique_union()) == []
 
 
 # This class must be defined globally to be picklable
-- 
GitLab