From 3f848dc44609c56572e8f748d8906ff039dc1192 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Mon, 5 Aug 2024 19:23:10 +0300
Subject: [PATCH] ci: pin jax<0.4.31 for mypy on gitlab

---
 .gitlab-ci.yml | 7 +++++--
 pyproject.toml | 1 +
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 969bc69..fa09197 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -107,11 +107,14 @@ Pylint:
 
 Mypy:
   script: |
-    EXTRA_INSTALL="jax[cpu]"
+    # NOTE: jax>=0.4.31 requires python 3.10 and uses pattern matching
+    # which conflicts with our mypy.python_version = '3.8' setting
+    EXTRA_INSTALL="mypy pytest jax[cpu]<0.4.31"
+
     curl -L -O https://tiker.net/ci-support-v0
     . ./ci-support-v0
+
     build_py_project_in_venv
-    python -m pip install mypy pytest
     ./run-mypy.sh
   tags:
   - python3
diff --git a/pyproject.toml b/pyproject.toml
index 6c4cdc4..ca64c70 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -119,6 +119,7 @@ known-local-folder = [
 lines-after-imports = 2
 
 [tool.mypy]
+# TODO: unpin jax version on CI when this gets bumped to 3.10
 python_version = "3.8"
 warn_unused_ignores = true
 # TODO: enable this
-- 
GitLab