diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py
index a70cbaa2eb8c3e057a75e0325003c478a41bae32..d767a083e7ab1c3481ee720440edb019cd07f173 100644
--- a/arraycontext/impl/jax/__init__.py
+++ b/arraycontext/impl/jax/__init__.py
@@ -121,7 +121,7 @@ class EagerJAXArrayContext(ArrayContext):
         return array
 
     def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
-        # TODO: See `jax.experiemental.maps.xmap`, probably that should be useful?
+        # TODO: See `jax.experimental.maps.xmap`, probably that should be useful?
         return array
 
     def call_loopy(self, t_unit, **kwargs):