diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 969bc695a06e2075fdba2f2bbc879a2a247794bf..fa09197f5b599753d2b7a25e5d019227d1a0ea6b 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -107,11 +107,14 @@ Pylint:
 
 Mypy:
   script: |
-    EXTRA_INSTALL="jax[cpu]"
+    # NOTE: jax>=0.4.31 requires python 3.10 and uses pattern matching
+    # which conflicts with our mypy.python_version = '3.8' setting
+    EXTRA_INSTALL="mypy pytest jax[cpu]<0.4.31"
+
     curl -L -O https://tiker.net/ci-support-v0
     . ./ci-support-v0
+
     build_py_project_in_venv
-    python -m pip install mypy pytest
     ./run-mypy.sh
   tags:
   - python3
diff --git a/pyproject.toml b/pyproject.toml
index 6c4cdc4af58881cc7289faa830c8e9f32a948717..ca64c70a9b3f5097a422c3d703c001a7a240d3e2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -119,6 +119,7 @@ known-local-folder = [
 lines-after-imports = 2
 
 [tool.mypy]
+# TODO: unpin jax version on CI when this gets bumped to 3.10
 python_version = "3.8"
 warn_unused_ignores = true
 # TODO: enable this