From 36af97d97ab944b7a049ef83843152111db0ae22 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 31 Oct 2022 12:33:58 -0500
Subject: [PATCH] Support PEP634-style pattern matching

---
 .github/workflows/ci.yml   |  5 +++-
 pymbolic/primitives.py     | 10 ++++++++
 test/test_pattern_match.py | 51 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 65 insertions(+), 1 deletion(-)
 create mode 100644 test/test_pattern_match.py

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ef12827..02e7785 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -23,6 +23,8 @@ jobs:
         -   name: "Main Script"
             run: |
                 curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-flake8.sh
+                # FIXME Remove when we're upgrading to Python 3.10.
+                rm test/test_pattern_match.py
                 . ./prepare-and-run-flake8.sh pymbolic test experiments
 
     pylint:
@@ -69,8 +71,9 @@ jobs:
                 # https://github.com/inducer/pymbolic/pull/66#issuecomment-950371315
                 pip install symengine || true
 
-                test_py_project
+                python3 -c 'import sys, os; sys.exit(not sys.version_info < (3, 10))' && rm test/test_pattern_match.py
 
+                test_py_project
 
     docs:
         name: Documentation
diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py
index 5ad5eb0..f5d7280 100644
--- a/pymbolic/primitives.py
+++ b/pymbolic/primitives.py
@@ -187,6 +187,11 @@ class Expression(ABC):
 
     Expression objects are immutable.
 
+    .. versionchanged:: 2022.2
+
+        `PEP 634 <https://peps.python.org/pep-0634/>`__-style pattern matching
+        is now supported when Pymbolic is used under Python 3.10.
+
     .. attribute:: a
 
     .. attribute:: attr
@@ -231,6 +236,11 @@ class Expression(ABC):
     def __getinitargs__(self):
         pass
 
+    @classmethod
+    @property
+    def __match_args__(cls):
+        return cls.init_arg_names
+
     @property
     def init_arg_names(self):
         raise NotImplementedError
diff --git a/test/test_pattern_match.py b/test/test_pattern_match.py
new file mode 100644
index 0000000..a28392d
--- /dev/null
+++ b/test/test_pattern_match.py
@@ -0,0 +1,51 @@
+__copyright__ = "Copyright (C) 2022 University of Illinois Board of Trustees"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import pymbolic.primitives as p
+
+
+def test_pattern_match():
+    from pymbolic import var
+
+    x = var("x")
+    xp1 = (x+1)
+    u = xp1**5 + x
+
+    match u:
+        case p.Sum((p.Power(base, exp), other_term)):
+            assert base is xp1
+            assert exp == 5
+            assert other_term is x
+
+        case _:
+            raise AssertionError()
+
+
+if __name__ == "__main__":
+    import sys
+    if len(sys.argv) > 1:
+        exec(sys.argv[1])
+    else:
+        from pytest import main
+        main([__file__])
+
+# vim: fdm=marker
-- 
GitLab