diff --git a/pymbolic/geometric_algebra/mapper.py b/pymbolic/geometric_algebra/mapper.py
index bf5b64ac6428bd013bad39a13c6eebe1a314eecc..d2508f9addc96868edcbcd568d72febf40bf2e11 100644
--- a/pymbolic/geometric_algebra/mapper.py
+++ b/pymbolic/geometric_algebra/mapper.py
@@ -25,16 +25,19 @@ THE SOFTWARE.
 
 # This is experimental, undocumented, and could go away any second.
 # Consider yourself warned.
-
+from collections.abc import Set
 from typing import ClassVar
 
 import pymbolic.geometric_algebra.primitives as prim
 from pymbolic.geometric_algebra import MultiVector
 from pymbolic.mapper import (
     CachedMapper,
+    CollectedT,
     Collector as CollectorBase,
     CombineMapper as CombineMapperBase,
     IdentityMapper as IdentityMapperBase,
+    P,
+    ResultT,
     WalkMapper as WalkMapperBase,
 )
 from pymbolic.mapper.constant_folder import (
@@ -46,50 +49,66 @@ from pymbolic.mapper.stringifier import (
     PREC_NONE,
     StringifyMapper as StringifyMapperBase,
 )
+from pymbolic.primitives import Expression
 
 
-class IdentityMapper(IdentityMapperBase):
-    def map_multivector_variable(self, expr):
+class IdentityMapper(IdentityMapperBase[P]):
+    def map_nabla(
+            self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs) -> Expression:
         return expr
 
-    map_nabla = map_multivector_variable
-    map_nabla_component = map_multivector_variable
+    def map_nabla_component(self,
+            expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs) -> Expression:
+        return expr
 
-    def map_derivative_source(self, expr):
-        operand = self.rec(expr.operand)
+    def map_derivative_source(self,
+                expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs
+            ) -> Expression:
+        operand = self.rec(expr.operand, *args, **kwargs)
         if operand is expr.operand:
             return expr
 
         return type(expr)(operand, expr.nabla_id)
 
 
-class CombineMapper(CombineMapperBase):
-    def map_derivative_source(self, expr):
-        return self.rec(expr.operand)
+class CombineMapper(CombineMapperBase[ResultT, P]):
+    def map_derivative_source(
+                self, expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs
+            ) -> ResultT:
+        return self.rec(expr.operand, *args, **kwargs)
 
 
-class Collector(CollectorBase):
-    def map_nabla(self, expr):
+class Collector(CollectorBase[CollectedT, P]):
+    def map_nabla(self,
+                expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs
+            ) -> Set[CollectedT]:
         return set()
 
-    map_nabla_component = map_nabla
+    def map_nabla_component(self,
+                expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs
+            ) -> Set[CollectedT]:
+        return set()
 
 
-class WalkMapper(WalkMapperBase):
-    def map_nabla(self, expr, *args):
-        self.visit(expr, *args)
-        self.post_visit(expr)
+class WalkMapper(WalkMapperBase[P]):
+    def map_nabla(self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs) -> None:
+        self.visit(expr, *args, **kwargs)
+        self.post_visit(expr, *args, **kwargs)
 
-    def map_nabla_component(self, expr, *args):
-        self.visit(expr, *args)
-        self.post_visit(expr)
+    def map_nabla_component(
+                self, expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs
+            ) -> None:
+        self.visit(expr, *args, **kwargs)
+        self.post_visit(expr, *args, **kwargs)
 
-    def map_derivative_source(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_derivative_source(
+                self, expr, *args: P.args, **kwargs: P.kwargs
+            ) -> None:
+        if not self.visit(expr, *args, **kwargs):
             return
 
-        self.rec(expr.operand)
-        self.post_visit(expr)
+        self.rec(expr.operand, *args, **kwargs)
+        self.post_visit(expr, *args, **kwargs)
 
 
 class EvaluationMapper(EvaluationMapperBase):
@@ -106,7 +125,7 @@ class EvaluationMapper(EvaluationMapperBase):
         return type(expr)(operand, expr.nabla_id)
 
 
-class StringifyMapper(StringifyMapperBase):
+class StringifyMapper(StringifyMapperBase[[]]):
     AXES: ClassVar[dict[int, str]] = {0: "x", 1: "y", 2: "z"}
 
     def map_nabla(self, expr, enclosing_prec):