diff --git a/pymbolic/interop/ast.py b/pymbolic/interop/ast.py
index 743edf6ef2a60419cf3107ad5a2612b2ec0c441c..4b6cc1cc2072cef40bdc4c065eb8cd4bd8c67b94 100644
--- a/pymbolic/interop/ast.py
+++ b/pymbolic/interop/ast.py
@@ -31,7 +31,7 @@ from typing import Any, ClassVar
 
 import pymbolic.primitives as p
 from pymbolic.mapper import CachedMapper
-from pymbolic.typing import ExpressionT, ScalarT
+from pymbolic.typing import ExpressionT
 
 
 __doc__ = r'''
@@ -263,7 +263,7 @@ class ASTToPymbolic(ASTMapper):
 
 # {{{ PymbolicToASTMapper
 
-class PymbolicToASTMapper(CachedMapper):
+class PymbolicToASTMapper(CachedMapper[ast.expr, []]):
     def map_variable(self, expr) -> ast.expr:
         return ast.Name(id=expr.name)
 
@@ -283,7 +283,7 @@ class PymbolicToASTMapper(CachedMapper):
     def map_product(self, expr: p.Product) -> ast.expr:
         return self._map_multi_children_op(expr.children, ast.Mult())
 
-    def map_constant(self, expr: ScalarT) -> ast.expr:
+    def map_constant(self, expr: object) -> ast.expr:
         if isinstance(expr, bool):
             return ast.NameConstant(expr)
         else:
@@ -393,7 +393,7 @@ class PymbolicToASTMapper(CachedMapper):
             raise NotImplementedError("Non-float nan not implemented")
 
     def map_slice(self, expr: p.Slice) -> ast.expr:
-        return ast.Slice(*[self.rec(child)
+        return ast.Slice(*[None if child is None else self.rec(child)
                            for child in expr.children])
 
     def map_numpy_array(self, expr) -> ast.expr:
diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py
index 22315f901ee450769eea811011c188d336ec5969..65dbfb923b9944ed7ad2a6eceecd55725c406e29 100644
--- a/pymbolic/mapper/__init__.py
+++ b/pymbolic/mapper/__init__.py
@@ -24,17 +24,41 @@ THE SOFTWARE.
 """
 
 from abc import ABC, abstractmethod
-from typing import Any
+from collections.abc import Mapping
+from typing import (
+    TYPE_CHECKING,
+    AbstractSet,
+    Callable,
+    Generic,
+    Hashable,
+    Iterable,
+    TypeVar,
+    cast,
+)
+from warnings import warn
 
 from immutabledict import immutabledict
+from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeIs
 
-import pymbolic.primitives as primitives
+import pymbolic.primitives as p
+from pymbolic.typing import ArithmeticExpressionT, ExpressionT
+
+
+if TYPE_CHECKING:
+    import numpy as np
+
+    from pymbolic.geometric_algebra import MultiVector
+    from pymbolic.rational import Rational
 
 
 __doc__ = """
 Basic dispatch
 --------------
 
+.. class:: ResultT
+
+    A type variable for the result returned by a :class:`Mapper`.
+
 .. autoclass:: Mapper
 
     .. automethod:: __call__
@@ -96,14 +120,20 @@ Base classes for mappers with memoization support
 """
 
 
-try:
-    import numpy
+if TYPE_CHECKING:
+    import numpy as np
 
-    def is_numpy_array(val):
-        return isinstance(val, numpy.ndarray)
-except ImportError:
-    def is_numpy_array(ary):
-        return False
+    def is_numpy_array(val) -> TypeIs[np.ndarray]:
+        return isinstance(val, np.ndarray)
+else:
+    try:
+        import numpy as np
+
+        def is_numpy_array(val):
+            return isinstance(val, np.ndarray)
+    except ImportError:
+        def is_numpy_array(ary):
+            return False
 
 
 class UnsupportedExpressionError(ValueError):
@@ -112,15 +142,28 @@ class UnsupportedExpressionError(ValueError):
 
 # {{{ mapper base
 
-class Mapper:
+ResultT = TypeVar("ResultT")
+
+# This ParamSpec could be marked contravariant (just like Callable is contravariant
+# in its arguments). As of mypy 1.14/Py3.13 (Nov 2024), mypy complains of as-yet
+# undefined semantics, so it's probably too soon.
+P = ParamSpec("P")
+
+
+class Mapper(Generic[ResultT, P]):
     """A visitor for trees of :class:`pymbolic.Expression`
     subclasses. Each expression-derived object is dispatched to the
     method named by the :attr:`pymbolic.Expression.mapper_method`
     attribute and if not found, the methods named by the class attribute
     *mapper_method* in the method resolution order of the object.
+
+    ..automethod:: handle_unsupported_expression
+    ..automethod:: __call__
+    ..automethod:: rec
     """
 
-    def handle_unsupported_expression(self, expr, *args, **kwargs):
+    def handle_unsupported_expression(self,
+            expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         """Mapper method that is invoked for
         :class:`pymbolic.Expression` subclasses for which a mapper
         method does not exist in this mapper.
@@ -130,7 +173,8 @@ class Mapper:
                 "{} cannot handle expressions of type {}".format(
                     type(self), type(expr)))
 
-    def __call__(self, expr, *args, **kwargs):
+    def __call__(self,
+             expr: ExpressionT, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         """Dispatch *expr* to its corresponding mapper method. Pass on
         ``*args`` and ``**kwargs`` unmodified.
 
@@ -148,7 +192,7 @@ class Mapper:
                 result = method(expr, *args, **kwargs)
                 return result
 
-        if isinstance(expr, primitives.Expression):
+        if isinstance(expr, p.Expression):
             for cls in type(expr).__mro__[1:]:
                 method_name = getattr(cls, "mapper_method", None)
                 if method_name:
@@ -162,8 +206,9 @@ class Mapper:
 
     rec = __call__
 
-    def rec_fallback(self, expr, *args, **kwargs):
-        if isinstance(expr, primitives.Expression):
+    def rec_fallback(self,
+            expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        if isinstance(expr, p.Expression):
             for cls in type(expr).__mro__[1:]:
                 method_name = getattr(cls, "mapper_method", None)
                 if method_name:
@@ -175,76 +220,188 @@ class Mapper:
         else:
             return self.map_foreign(expr, *args, **kwargs)
 
-    def map_algebraic_leaf(self, expr, *args, **kwargs):
+    def map_algebraic_leaf(self,
+            expr: p.AlgebraicLeaf,
+            *args: P.args, **kwargs: P.kwargs) -> ResultT:
         raise NotImplementedError
 
-    def map_variable(self, expr, *args, **kwargs):
+    def map_variable(self,
+            expr: p.Variable, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.map_algebraic_leaf(expr, *args, **kwargs)
 
-    def map_subscript(self, expr, *args, **kwargs):
+    def map_subscript(self,
+            expr: p.Subscript, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.map_algebraic_leaf(expr, *args, **kwargs)
 
-    def map_call(self, expr, *args, **kwargs):
+    def map_call(self,
+            expr: p.Call, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.map_algebraic_leaf(expr, *args, **kwargs)
 
-    def map_lookup(self, expr, *args, **kwargs):
+    def map_call_with_kwargs(self,
+            expr: p.CallWithKwargs,
+            *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.map_algebraic_leaf(expr, *args, **kwargs)
 
-    def map_if_positive(self, expr, *args, **kwargs):
+    def map_lookup(self,
+            expr: p.Lookup, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.map_algebraic_leaf(expr, *args, **kwargs)
 
-    def map_rational(self, expr, *args, **kwargs):
-        return self.map_quotient(expr, *args, **kwargs)
+    def map_if(self,
+            expr: p.If, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        raise NotImplementedError
 
-    def map_quotient(self, expr, *args, **kwargs):
+    def map_sum(self,
+            expr: p.Sum, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         raise NotImplementedError
 
-    def map_constant(self, expr, *args, **kwargs):
+    def map_product(self,
+            expr: p.Product, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        raise NotImplementedError
+
+    def map_rational(self,
+            expr: Rational, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        raise NotImplementedError
+
+    def map_quotient(self,
+            expr: p.Quotient, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        raise NotImplementedError
+
+    def map_floor_div(self,
+            expr: p.FloorDiv, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        raise NotImplementedError
+
+    def map_remainder(self,
+            expr: p.Remainder, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        raise NotImplementedError
+
+    def map_constant(self,
+            expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        raise NotImplementedError
+
+    def map_comparison(self,
+            expr: p.Comparison, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        raise NotImplementedError
+
+    def map_min(self,
+            expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        raise NotImplementedError
+
+    def map_max(self,
+            expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        raise NotImplementedError
+
+    def map_list(self,
+            expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        raise NotImplementedError
+
+    def map_tuple(self,
+            expr: tuple[ExpressionT, ...],
+            *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        raise NotImplementedError
+
+    def map_numpy_array(self,
+            expr: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        raise NotImplementedError
+
+    def map_left_shift(self,
+            expr: p.LeftShift, *args: P.args, **kwargs: P.kwargs
+            ) -> ResultT:
         raise NotImplementedError
 
-    def map_list(self, expr, *args, **kwargs):
+    def map_right_shift(self,
+                expr: p.RightShift, *args: P.args, **kwargs: P.kwargs
+            ) -> ResultT:
         raise NotImplementedError
 
-    def map_tuple(self, expr, *args, **kwargs):
+    def map_bitwise_not(self,
+                expr: p.BitwiseNot, *args: P.args, **kwargs: P.kwargs
+            ) -> ResultT:
         raise NotImplementedError
 
-    def map_numpy_array(self, expr, *args, **kwargs):
+    def map_bitwise_or(self,
+                expr: p.BitwiseOr, *args: P.args, **kwargs: P.kwargs
+            ) -> ResultT:
         raise NotImplementedError
 
-    def map_nan(self, expr, *args, **kwargs):
+    def map_bitwise_and(self,
+                expr: p.BitwiseAnd, *args: P.args, **kwargs: P.kwargs
+            ) -> ResultT:
+        raise NotImplementedError
+
+    def map_bitwise_xor(self,
+                expr: p.BitwiseXor, *args: P.args, **kwargs: P.kwargs
+            ) -> ResultT:
+        raise NotImplementedError
+
+    def map_logical_not(self,
+                expr: p.LogicalNot, *args: P.args, **kwargs: P.kwargs
+            ) -> ResultT:
+        raise NotImplementedError
+
+    def map_logical_or(self,
+                expr: p.LogicalOr, *args: P.args, **kwargs: P.kwargs
+            ) -> ResultT:
+        raise NotImplementedError
+
+    def map_logical_and(self,
+                expr: p.LogicalAnd, *args: P.args, **kwargs: P.kwargs
+            ) -> ResultT:
+        raise NotImplementedError
+
+    def map_nan(self,
+                expr: p.NaN,
+                *args: P.args,
+                **kwargs: P.kwargs
+            ) -> ResultT:
         return self.map_algebraic_leaf(expr, *args, **kwargs)
 
-    def map_foreign(self, expr, *args, **kwargs):
+    def map_foreign(self,
+                expr: object,
+                *args: P.args,
+                **kwargs: P.kwargs
+            ) -> ResultT:
         """Mapper method dispatch for non-:mod:`pymbolic` objects."""
 
-        if isinstance(expr, primitives.VALID_CONSTANT_CLASSES):
+        if isinstance(expr, p.VALID_CONSTANT_CLASSES):
             return self.map_constant(expr, *args, **kwargs)
         elif is_numpy_array(expr):
             return self.map_numpy_array(expr, *args, **kwargs)
-        elif isinstance(expr, list):
-            return self.map_list(expr, *args, **kwargs)
         elif isinstance(expr, tuple):
             return self.map_tuple(expr, *args, **kwargs)
+        elif isinstance(expr, list):
+            warn("List found in expression graph. "
+                 "This is deprecated and will stop working in 2025. "
+                 "Use tuples instead.", DeprecationWarning, stacklevel=2
+             )
+            return self.map_list(expr, *args, **kwargs)
         else:
             raise ValueError(
                     "{} encountered invalid foreign object: {}".format(
                         self.__class__, repr(expr)))
 
 
-_NOT_IN_CACHE = object()
+class _NotInCache:
+    pass
+
 
+CacheKeyT: TypeAlias = Hashable
 
-class CachedMapper(Mapper):
+
+class CachedMapper(Mapper[ResultT, P]):
     """
     A mapper that memoizes the mapped result for the expressions traversed.
 
     .. automethod:: get_cache_key
     """
-    def __init__(self):
-        self._cache: dict[Any, Any] = {}
+    def __init__(self) -> None:
+        self._cache: dict[CacheKeyT, ResultT] = {}
         Mapper.__init__(self)
 
-    def get_cache_key(self, expr, *args, **kwargs):
+    def get_cache_key(self,
+              expr: ExpressionT,
+              *args: P.args,
+              **kwargs: P.kwargs
+          ) -> CacheKeyT:
         """
         Returns the key corresponding to which the result of a mapper method is
         stored in the cache.
@@ -260,16 +417,23 @@ class CachedMapper(Mapper):
         # and "4 == 4.0", but their traversal results cannot be re-used.
         return (type(expr), expr, args, immutabledict(kwargs))
 
-    def __call__(self, expr, *args, **kwargs):
+    def __call__(self,
+                 expr: ExpressionT,
+                 *args: P.args,
+                 **kwargs: P.kwargs
+             ) -> ResultT:
         result = self._cache.get(
                 (cache_key := self.get_cache_key(expr, *args, **kwargs)),
-                _NOT_IN_CACHE)
-        if result is not _NOT_IN_CACHE:
+                _NotInCache)
+        if not isinstance(result, type):
             return result
 
         method_name = getattr(expr, "mapper_method", None)
         if method_name is not None:
-            method = getattr(self, method_name, None)
+            method = cast(
+                Callable[Concatenate[ExpressionT, P], ResultT],
+                getattr(self, method_name, None)
+                )
             if method is not None:
                 result = method(expr, *args, **kwargs)
                 self._cache[cache_key] = result
@@ -286,7 +450,7 @@ class CachedMapper(Mapper):
 
 # {{{ combine mapper
 
-class CombineMapper(Mapper):
+class CombineMapper(Mapper[ResultT, P]):
     """A mapper whose goal it is to *combine* all branches of the expression
     tree into one final result. The default implementation of all mapper
     methods simply recurse (:meth:`Mapper.rec`) on all branches emanating from
@@ -304,16 +468,19 @@ class CombineMapper(Mapper):
     :class:`pymbolic.mapper.dependency.DependencyMapper` is another example.
     """
 
-    def combine(self, values):
+    def combine(self, values: Iterable[ResultT]) -> ResultT:
         raise NotImplementedError
 
-    def map_call(self, expr, *args, **kwargs):
+    def map_call(self,
+            expr: p.Call, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.combine((
             self.rec(expr.function, *args, **kwargs),
             *[self.rec(child, *args, **kwargs) for child in expr.parameters]
             ))
 
-    def map_call_with_kwargs(self, expr, *args, **kwargs):
+    def map_call_with_kwargs(self,
+            expr: p.CallWithKwargs,
+            *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.combine((
             self.rec(expr.function, *args, **kwargs),
             *[self.rec(child, *args, **kwargs) for child in expr.parameters],
@@ -321,87 +488,141 @@ class CombineMapper(Mapper):
               for child in expr.kw_parameters.values()]
             ))
 
-    def map_subscript(self, expr, *args, **kwargs):
+    def map_subscript(self,
+            expr: p.Subscript, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.combine(
                 [self.rec(expr.aggregate, *args, **kwargs),
                     self.rec(expr.index, *args, **kwargs)])
 
-    def map_lookup(self, expr, *args, **kwargs):
+    def map_lookup(self,
+            expr: p.Lookup, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.rec(expr.aggregate, *args, **kwargs)
 
-    def map_sum(self, expr, *args, **kwargs):
+    def map_sum(self,
+            expr: p.Sum, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.combine(self.rec(child, *args, **kwargs)
                 for child in expr.children)
 
-    map_product = map_sum
+    def map_product(self,
+            expr: p.Product, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        return self.combine(self.rec(child, *args, **kwargs)
+                for child in expr.children)
 
-    def map_quotient(self, expr, *args, **kwargs):
+    def map_quotient(self,
+            expr: p.Quotient, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.combine((
             self.rec(expr.numerator, *args, **kwargs),
             self.rec(expr.denominator, *args, **kwargs)))
 
-    map_floor_div = map_quotient
-    map_remainder = map_quotient
+    def map_floor_div(self,
+            expr: p.FloorDiv, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        return self.combine((
+            self.rec(expr.numerator, *args, **kwargs),
+            self.rec(expr.denominator, *args, **kwargs)))
+
+    def map_remainder(self,
+            expr: p.Remainder, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        return self.combine((
+            self.rec(expr.numerator, *args, **kwargs),
+            self.rec(expr.denominator, *args, **kwargs)))
 
-    def map_power(self, expr, *args, **kwargs):
+    def map_power(self,
+            expr: p.Power, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.combine((
                 self.rec(expr.base, *args, **kwargs),
                 self.rec(expr.exponent, *args, **kwargs)))
 
-    def map_polynomial(self, expr, *args, **kwargs):
+    def map_left_shift(self,
+            expr: p.LeftShift, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.combine((
-            self.rec(expr.base, *args, **kwargs),
-            *[self.rec(coeff, *args, **kwargs) for exp, coeff in expr.data]
-            ))
+            self.rec(expr.shiftee, *args, **kwargs),
+            self.rec(expr.shift, *args, **kwargs)))
 
-    def map_left_shift(self, expr, *args, **kwargs):
+    def map_right_shift(self,
+            expr: p.RightShift, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.combine((
             self.rec(expr.shiftee, *args, **kwargs),
             self.rec(expr.shift, *args, **kwargs)))
 
-    map_right_shift = map_left_shift
+    def map_bitwise_not(self,
+            expr: p.BitwiseNot, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        return self.rec(expr.child, *args, **kwargs)
 
-    def map_bitwise_not(self, expr, *args, **kwargs):
+    def map_bitwise_or(self,
+            expr: p.BitwiseOr, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        return self.combine(self.rec(child, *args, **kwargs)
+                for child in expr.children)
+
+    def map_bitwise_and(self,
+            expr: p.BitwiseAnd, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        return self.combine(self.rec(child, *args, **kwargs)
+                for child in expr.children)
+
+    def map_bitwise_xor(self,
+            expr: p.BitwiseXor, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        return self.combine(self.rec(child, *args, **kwargs)
+                for child in expr.children)
+
+    def map_logical_not(self,
+            expr: p.LogicalNot, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.rec(expr.child, *args, **kwargs)
-    map_bitwise_or = map_sum
-    map_bitwise_xor = map_sum
-    map_bitwise_and = map_sum
 
-    map_logical_not = map_bitwise_not
-    map_logical_and = map_sum
-    map_logical_or = map_sum
+    def map_logical_or(self,
+            expr: p.LogicalOr, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        return self.combine(self.rec(child, *args, **kwargs)
+                for child in expr.children)
+
+    def map_logical_and(self,
+            expr: p.LogicalAnd, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        return self.combine(self.rec(child, *args, **kwargs)
+                for child in expr.children)
 
-    def map_comparison(self, expr, *args, **kwargs):
+    def map_comparison(self,
+            expr: p.Comparison, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.combine((
             self.rec(expr.left, *args, **kwargs),
             self.rec(expr.right, *args, **kwargs)))
 
-    map_max = map_sum
-    map_min = map_sum
+    def map_max(self,
+            expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        return self.combine(self.rec(child, *args, **kwargs)
+                for child in expr.children)
+
+    def map_min(self,
+            expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> ResultT:
+        return self.combine(self.rec(child, *args, **kwargs)
+                for child in expr.children)
 
-    def map_list(self, expr, *args, **kwargs):
+    def map_tuple(self,
+                expr: tuple[ExpressionT, ...], *args: P.args, **kwargs: P.kwargs
+            ) -> ResultT:
         return self.combine(self.rec(child, *args, **kwargs) for child in expr)
 
-    map_tuple = map_list
+    def map_list(self,
+                expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs
+            ) -> ResultT:
+        return self.combine(self.rec(child, *args, **kwargs) for child in expr)
 
-    def map_numpy_array(self, expr, *args, **kwargs):
+    def map_numpy_array(self,
+                expr: np.ndarray, *args: P.args, **kwargs: P.kwargs
+            ) -> ResultT:
         return self.combine(self.rec(el, *args, **kwargs) for el in expr.flat)
 
-    def map_multivector(self, expr, *args, **kwargs):
+    def map_multivector(self,
+                expr: MultiVector[ArithmeticExpressionT],
+                *args: P.args, **kwargs: P.kwargs
+            ) -> ResultT:
         return self.combine(
                 self.rec(coeff, *args, **kwargs)
                 for bits, coeff in expr.data.items())
 
-    def map_common_subexpression(self, expr, *args, **kwargs):
+    def map_common_subexpression(self,
+                expr: p.CommonSubexpression, *args: P.args, **kwargs: P.kwargs
+            ) -> ResultT:
         return self.rec(expr.child, *args, **kwargs)
 
-    def map_if_positive(self, expr, *args, **kwargs):
-        return self.combine([
-            self.rec(expr.criterion, *args, **kwargs),
-            self.rec(expr.then, *args, **kwargs),
-            self.rec(expr.else_, *args, **kwargs)])
-
-    def map_if(self, expr, *args, **kwargs):
+    def map_if(self,
+            expr: p.If, *args: P.args, **kwargs: P.kwargs) -> ResultT:
         return self.combine([
             self.rec(expr.condition, *args, **kwargs),
             self.rec(expr.then, *args, **kwargs),
@@ -416,7 +637,10 @@ class CachedCombineMapper(CachedMapper, CombineMapper):
 
 # {{{ collector
 
-class Collector(CombineMapper):
+CollectedT = TypeVar("CollectedT")
+
+
+class Collector(CombineMapper[AbstractSet[CollectedT], P]):
     """A subclass of :class:`CombineMapper` for the common purpose of
     collecting data derived from an expression in a set that gets 'unioned'
     across children at each non-leaf node in the expression tree.
@@ -426,19 +650,36 @@ class Collector(CombineMapper):
     .. versionadded:: 2014.3
     """
 
-    def combine(self, values):
+    def combine(self,
+                values: Iterable[AbstractSet[CollectedT]]
+            ) -> AbstractSet[CollectedT]:
         import operator
         from functools import reduce
         return reduce(operator.or_, values, set())
 
-    def map_constant(self, expr, *args, **kwargs):
+    def map_constant(self, expr: object,
+                     *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]:
         return set()
 
-    map_variable = map_constant
-    map_wildcard = map_constant
-    map_dot_wildcard = map_constant
-    map_star_wildcard = map_constant
-    map_function_symbol = map_constant
+    def map_variable(self, expr: p.Variable,
+                     *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]:
+        return set()
+
+    def map_wildcard(self, expr: p.Wildcard,
+                     *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]:
+        return set()
+
+    def map_dot_wildcard(self, expr: p.DotWildcard,
+                     *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]:
+        return set()
+
+    def map_star_wildcard(self, expr: p.StarWildcard,
+                     *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]:
+        return set()
+
+    def map_function_symbol(self, expr: p.FunctionSymbol,
+                     *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]:
+        return set()
 
 
 class CachedCollector(CachedMapper, Collector):
@@ -449,34 +690,59 @@ class CachedCollector(CachedMapper, Collector):
 
 # {{{ identity mapper
 
-class IdentityMapper(Mapper):
+class IdentityMapper(Mapper[ExpressionT, P]):
     """A :class:`Mapper` whose default mapper methods
     make a deep copy of each subexpression.
 
     See :ref:`custom-manipulation` for an example of the
     manipulations that can be implemented this way.
+
+    .. automethod:: rec_arith
     """
-    def map_constant(self, expr, *args, **kwargs):
+
+    def rec_arith(self,
+                expr: ArithmeticExpressionT, *args: P.args, **kwargs: P.kwargs
+            ) -> ArithmeticExpressionT:
+        res = self.rec(expr, *args, **kwargs)
+        assert p.is_arithmetic_expression(res)
+        return res
+
+    def map_constant(self,
+                expr: object, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         # leaf -- no need to rebuild
+        assert p.is_valid_operand(expr)
         return expr
 
-    def map_variable(self, expr, *args, **kwargs):
+    def map_variable(self,
+                expr: p.Variable, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         # leaf -- no need to rebuild
         return expr
 
-    def map_wildcard(self, expr, *args, **kwargs):
+    def map_wildcard(self,
+                expr: p.Wildcard, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         return expr
 
-    def map_dot_wildcard(self, expr, *args, **kwargs):
+    def map_dot_wildcard(self,
+                expr: p.DotWildcard, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         return expr
 
-    def map_star_wildcard(self, expr, *args, **kwargs):
+    def map_star_wildcard(self,
+                expr: p.StarWildcard, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         return expr
 
-    def map_function_symbol(self, expr, *args, **kwargs):
+    def map_function_symbol(self,
+                expr: p.FunctionSymbol, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         return expr
 
-    def map_call(self, expr, *args, **kwargs):
+    def map_call(self,
+                expr: p.Call, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         function = self.rec(expr.function, *args, **kwargs)
         parameters = tuple([
             self.rec(child, *args, **kwargs) for child in expr.parameters
@@ -488,12 +754,14 @@ class IdentityMapper(Mapper):
 
         return type(expr)(function, parameters)
 
-    def map_call_with_kwargs(self, expr, *args, **kwargs):
+    def map_call_with_kwargs(self,
+                expr: p.CallWithKwargs, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         function = self.rec(expr.function, *args, **kwargs)
         parameters = tuple([
             self.rec(child, *args, **kwargs) for child in expr.parameters
             ])
-        kw_parameters = immutabledict({
+        kw_parameters: Mapping[str, ExpressionT] = immutabledict({
                 key: self.rec(val, *args, **kwargs)
                 for key, val in expr.kw_parameters.items()})
 
@@ -505,20 +773,26 @@ class IdentityMapper(Mapper):
             return expr
         return type(expr)(function, parameters, kw_parameters)
 
-    def map_subscript(self, expr, *args, **kwargs):
+    def map_subscript(self,
+                expr: p.Subscript, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         aggregate = self.rec(expr.aggregate, *args, **kwargs)
         index = self.rec(expr.index, *args, **kwargs)
         if aggregate is expr.aggregate and index is expr.index:
             return expr
         return type(expr)(aggregate, index)
 
-    def map_lookup(self, expr, *args, **kwargs):
+    def map_lookup(self,
+                expr: p.Lookup, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         aggregate = self.rec(expr.aggregate, *args, **kwargs)
         if aggregate is expr.aggregate:
             return expr
         return type(expr)(aggregate, expr.name)
 
-    def map_sum(self, expr, *args, **kwargs):
+    def map_sum(self,
+                expr: p.Sum, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         children = [self.rec(child, *args, **kwargs) for child in expr.children]
         if all(child is orig_child
                 for child, orig_child in zip(children, expr.children)):
@@ -526,41 +800,81 @@ class IdentityMapper(Mapper):
 
         return type(expr)(tuple(children))
 
-    map_product = map_sum
+    def map_product(self,
+                expr: p.Product, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
+        children = [self.rec(child, *args, **kwargs) for child in expr.children]
+        if all(child is orig_child
+                for child, orig_child in zip(children, expr.children)):
+            return expr
+
+        return type(expr)(tuple(children))
+
+    def map_quotient(self,
+                expr: p.Quotient, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
+        numerator = self.rec_arith(expr.numerator, *args, **kwargs)
+        denominator = self.rec_arith(expr.denominator, *args, **kwargs)
+        if numerator is expr.numerator and denominator is expr.denominator:
+            return expr
+        return expr.__class__(numerator, denominator)
 
-    def map_quotient(self, expr, *args, **kwargs):
-        numerator = self.rec(expr.numerator, *args, **kwargs)
-        denominator = self.rec(expr.denominator, *args, **kwargs)
+    def map_floor_div(self,
+                expr: p.FloorDiv, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
+        numerator = self.rec_arith(expr.numerator, *args, **kwargs)
+        denominator = self.rec_arith(expr.denominator, *args, **kwargs)
         if numerator is expr.numerator and denominator is expr.denominator:
             return expr
         return expr.__class__(numerator, denominator)
 
-    map_floor_div = map_quotient
-    map_remainder = map_quotient
+    def map_remainder(self,
+                expr: p.Remainder, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
+        numerator = self.rec_arith(expr.numerator, *args, **kwargs)
+        denominator = self.rec_arith(expr.denominator, *args, **kwargs)
+        if numerator is expr.numerator and denominator is expr.denominator:
+            return expr
+        return expr.__class__(numerator, denominator)
 
-    def map_power(self, expr, *args, **kwargs):
-        base = self.rec(expr.base, *args, **kwargs)
-        exponent = self.rec(expr.exponent, *args, **kwargs)
+    def map_power(self,
+                expr: p.Power, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
+        base = self.rec_arith(expr.base, *args, **kwargs)
+        exponent = self.rec_arith(expr.exponent, *args, **kwargs)
         if base is expr.base and exponent is expr.exponent:
             return expr
         return expr.__class__(base, exponent)
 
-    def map_left_shift(self, expr, *args, **kwargs):
+    def map_left_shift(self,
+                expr: p.LeftShift, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         shiftee = self.rec(expr.shiftee, *args, **kwargs)
         shift = self.rec(expr.shift, *args, **kwargs)
         if shiftee is expr.shiftee and shift is expr.shift:
             return expr
         return type(expr)(shiftee, shift)
 
-    map_right_shift = map_left_shift
+    def map_right_shift(self,
+                expr: p.RightShift, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
+        shiftee = self.rec(expr.shiftee, *args, **kwargs)
+        shift = self.rec(expr.shift, *args, **kwargs)
+        if shiftee is expr.shiftee and shift is expr.shift:
+            return expr
+        return type(expr)(shiftee, shift)
 
-    def map_bitwise_not(self, expr, *args, **kwargs):
+    def map_bitwise_not(self,
+                expr: p.BitwiseNot, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         child = self.rec(expr.child, *args, **kwargs)
         if child is expr.child:
             return expr
         return type(expr)(child)
 
-    def map_bitwise_or(self, expr, *args, **kwargs):
+    def map_bitwise_or(self,
+                expr: p.BitwiseOr, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         children = [self.rec(child, *args, **kwargs) for child in expr.children]
         if all(child is orig_child
                 for child, orig_child in zip(children, expr.children)):
@@ -568,14 +882,57 @@ class IdentityMapper(Mapper):
 
         return type(expr)(tuple(children))
 
-    map_bitwise_xor = map_bitwise_or
-    map_bitwise_and = map_bitwise_or
+    def map_bitwise_and(self,
+                expr: p.BitwiseAnd, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
+        children = [self.rec(child, *args, **kwargs) for child in expr.children]
+        if all(child is orig_child
+                for child, orig_child in zip(children, expr.children)):
+            return expr
 
-    map_logical_not = map_bitwise_not
-    map_logical_or = map_bitwise_or
-    map_logical_and = map_bitwise_or
+        return type(expr)(tuple(children))
 
-    def map_comparison(self, expr, *args, **kwargs):
+    def map_bitwise_xor(self,
+                expr: p.BitwiseXor, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
+        children = [self.rec(child, *args, **kwargs) for child in expr.children]
+        if all(child is orig_child
+                for child, orig_child in zip(children, expr.children)):
+            return expr
+
+        return type(expr)(tuple(children))
+
+    def map_logical_not(self,
+                expr: p.LogicalNot, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
+        child = self.rec(expr.child, *args, **kwargs)
+        if child is expr.child:
+            return expr
+        return type(expr)(child)
+
+    def map_logical_or(self,
+                expr: p.LogicalOr, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
+        children = [self.rec(child, *args, **kwargs) for child in expr.children]
+        if all(child is orig_child
+                for child, orig_child in zip(children, expr.children)):
+            return expr
+
+        return type(expr)(tuple(children))
+
+    def map_logical_and(self,
+                expr: p.LogicalAnd, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
+        children = [self.rec(child, *args, **kwargs) for child in expr.children]
+        if all(child is orig_child
+                for child, orig_child in zip(children, expr.children)):
+            return expr
+
+        return type(expr)(tuple(children))
+
+    def map_comparison(self,
+                expr: p.Comparison, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         left = self.rec(expr.left, *args, **kwargs)
         right = self.rec(expr.right, *args, **kwargs)
         if left is expr.left and right is expr.right:
@@ -583,10 +940,16 @@ class IdentityMapper(Mapper):
 
         return type(expr)(left, expr.operator, right)
 
-    def map_list(self, expr, *args, **kwargs):
-        return [self.rec(child, *args, **kwargs) for child in expr]
+    def map_list(self,
+                expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
+
+        # True fact: lists aren't expressions
+        return [self.rec(child, *args, **kwargs) for child in expr]  # type: ignore[return-value]
 
-    def map_tuple(self, expr, *args, **kwargs):
+    def map_tuple(self,
+                expr: tuple[ExpressionT, ...], *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         children = [self.rec(child, *args, **kwargs) for child in expr]
         if all(child is orig_child
                 for child, orig_child in zip(children, expr)):
@@ -594,21 +957,28 @@ class IdentityMapper(Mapper):
 
         return tuple(children)
 
-    def map_numpy_array(self, expr, *args, **kwargs):
+    def map_numpy_array(self,
+                expr: np.ndarray, *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
+
         import numpy
         result = numpy.empty(expr.shape, dtype=object)
         for i in numpy.ndindex(expr.shape):
             result[i] = self.rec(expr[i], *args, **kwargs)
-        return result
 
-    def map_multivector(self, expr, *args, **kwargs):
+        # True fact: ndarrays aren't expressions
+        return result  # type: ignore[return-value]
+
+    def map_multivector(self,
+                expr: MultiVector[ArithmeticExpressionT],
+                *args: P.args, **kwargs: P.kwargs
+            ) -> ExpressionT:
         return expr.map(lambda ch: self.rec(ch, *args, **kwargs))
 
-    def map_common_subexpression(self, expr, *args, **kwargs):
-        from pymbolic.primitives import is_zero
+    def map_common_subexpression(self,
+                expr: p.CommonSubexpression,
+                *args: P.args, **kwargs: P.kwargs) -> ExpressionT:
         result = self.rec(expr.child, *args, **kwargs)
-        if is_zero(result):
-            return 0
         if result is expr.child:
             return expr
 
@@ -618,7 +988,9 @@ class IdentityMapper(Mapper):
                 expr.scope,
                 **expr.get_extra_properties())
 
-    def map_substitution(self, expr, *args, **kwargs):
+    def map_substitution(self,
+                 expr: p.Substitution,
+                 *args: P.args, **kwargs: P.kwargs) -> ExpressionT:
         child = self.rec(expr.child, *args, **kwargs)
         values = tuple([self.rec(v, *args, **kwargs) for v in expr.values])
         if child is expr.child and all(val is orig_val
@@ -627,36 +999,29 @@ class IdentityMapper(Mapper):
 
         return type(expr)(child, expr.variables, values)
 
-    def map_derivative(self, expr, *args, **kwargs):
+    def map_derivative(self,
+                expr: p.Derivative,
+                *args: P.args, **kwargs: P.kwargs) -> ExpressionT:
         child = self.rec(expr.child, *args, **kwargs)
         if child is expr.child:
             return expr
 
         return type(expr)(child, expr.variables)
 
-    def map_slice(self, expr, *args, **kwargs):
-        children = tuple([
+    def map_slice(self,
+                expr: p.Slice,
+                *args: P.args, **kwargs: P.kwargs) -> ExpressionT:
+        children: p.SliceChildrenT = cast(p.SliceChildrenT, tuple([
             None if child is None else self.rec(child, *args, **kwargs)
             for child in expr.children
-            ])
+            ]))
         if all(child is orig_child
                 for child, orig_child in zip(children, expr.children)):
             return expr
 
         return type(expr)(children)
 
-    def map_if_positive(self, expr, *args, **kwargs):
-        criterion = self.rec(expr.criterion, *args, **kwargs)
-        then = self.rec(expr.then, *args, **kwargs)
-        else_ = self.rec(expr.else_, *args, **kwargs)
-        if criterion is expr.criterion \
-                and then is expr.then \
-                and else_ is expr.else_:
-            return expr
-
-        return type(expr)(criterion, then, else_)
-
-    def map_if(self, expr, *args, **kwargs):
+    def map_if(self, expr: p.If, *args: P.args, **kwargs: P.kwargs) -> ExpressionT:
         condition = self.rec(expr.condition, *args, **kwargs)
         then = self.rec(expr.then, *args, **kwargs)
         else_ = self.rec(expr.else_, *args, **kwargs)
@@ -667,7 +1032,7 @@ class IdentityMapper(Mapper):
 
         return type(expr)(condition, then, else_)
 
-    def map_min(self, expr, *args, **kwargs):
+    def map_min(self, expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> ExpressionT:
         children = tuple([
             self.rec(child, *args, **kwargs) for child in expr.children
             ])
@@ -677,14 +1042,22 @@ class IdentityMapper(Mapper):
 
         return type(expr)(children)
 
-    map_max = map_min
+    def map_max(self, expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> ExpressionT:
+        children = tuple([
+            self.rec(child, *args, **kwargs) for child in expr.children
+            ])
+        if all(child is orig_child
+                for child, orig_child in zip(children, expr.children)):
+            return expr
+
+        return type(expr)(children)
 
-    def map_nan(self, expr, *args, **kwargs):
+    def map_nan(self, expr: p.NaN, *args: P.args, **kwargs: P.kwargs) -> ExpressionT:
         # Leaf node -- don't recurse
         return expr
 
 
-class CachedIdentityMapper(CachedMapper, IdentityMapper):
+class CachedIdentityMapper(CachedMapper[ExpressionT, P], IdentityMapper[P]):
     pass
 
 # }}}
@@ -692,7 +1065,7 @@ class CachedIdentityMapper(CachedMapper, IdentityMapper):
 
 # {{{ walk mapper
 
-class WalkMapper(Mapper):
+class WalkMapper(Mapper[None, P]):
     """A mapper whose default mapper method implementations simply recurse
     without propagating any result. Also calls :meth:`visit` for each
     visited subexpression.
@@ -709,21 +1082,39 @@ class WalkMapper(Mapper):
         Is called after a node's children are visited.
     """
 
-    def map_constant(self, expr, *args, **kwargs):
+    def map_constant(self, expr: object, *args: P.args, **kwargs: P.kwargs) -> None:
+        self.visit(expr, *args, **kwargs)
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_variable(self, expr: p.Variable, *args: P.args, **kwargs: P.kwargs) -> None:
+        self.visit(expr, *args, **kwargs)
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_wildcard(self, expr: p.Wildcard, *args: P.args, **kwargs: P.kwargs) -> None:
+        self.visit(expr, *args, **kwargs)
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_dot_wildcard(self,
+            expr: p.DotWildcard, *args: P.args, **kwargs: P.kwargs) -> None:
         self.visit(expr, *args, **kwargs)
         self.post_visit(expr, *args, **kwargs)
 
-    def map_variable(self, expr, *args, **kwargs):
+    def map_star_wildcard(self,
+            expr: p.StarWildcard, *args: P.args, **kwargs: P.kwargs) -> None:
         self.visit(expr, *args, **kwargs)
         self.post_visit(expr, *args, **kwargs)
 
-    map_wildcard = map_variable
-    map_dot_wildcard = map_variable
-    map_star_wildcard = map_variable
-    map_function_symbol = map_variable
-    map_nan = map_variable
+    def map_function_symbol(self,
+            expr: p.FunctionSymbol, *args: P.args, **kwargs: P.kwargs) -> None:
+        self.visit(expr, *args, **kwargs)
+        self.post_visit(expr, *args, **kwargs)
 
-    def map_call(self, expr, *args, **kwargs):
+    def map_nan(self,
+            expr: p.NaN, *args: P.args, **kwargs: P.kwargs) -> None:
+        self.visit(expr, *args, **kwargs)
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_call(self, expr: p.Call, *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -733,7 +1124,9 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    def map_call_with_kwargs(self, expr, *args, **kwargs):
+    def map_call_with_kwargs(self,
+                expr: p.CallWithKwargs,
+                *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -746,7 +1139,9 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    def map_subscript(self, expr, *args, **kwargs):
+    def map_subscript(self,
+                expr: p.Subscript,
+                *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -755,7 +1150,8 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    def map_lookup(self, expr, *args, **kwargs):
+    def map_lookup(self,
+                expr: p.Lookup, *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -763,7 +1159,7 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    def map_sum(self, expr, *args, **kwargs):
+    def map_sum(self, expr: p.Sum, *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -772,9 +1168,16 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    map_product = map_sum
+    def map_product(self, expr: p.Product, *args: P.args, **kwargs: P.kwargs) -> None:
+        if not self.visit(expr, *args, **kwargs):
+            return
 
-    def map_quotient(self, expr, *args, **kwargs):
+        for child in expr.children:
+            self.rec(child, *args, **kwargs)
+
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_quotient(self, expr: p.Quotient, *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -783,10 +1186,27 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    map_floor_div = map_quotient
-    map_remainder = map_quotient
+    def map_floor_div(self,
+            expr: p.FloorDiv, *args: P.args, **kwargs: P.kwargs) -> None:
+        if not self.visit(expr, *args, **kwargs):
+            return
+
+        self.rec(expr.numerator, *args, **kwargs)
+        self.rec(expr.denominator, *args, **kwargs)
+
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_remainder(self,
+            expr: p.Remainder, *args: P.args, **kwargs: P.kwargs) -> None:
+        if not self.visit(expr, *args, **kwargs):
+            return
+
+        self.rec(expr.numerator, *args, **kwargs)
+        self.rec(expr.denominator, *args, **kwargs)
+
+        self.post_visit(expr, *args, **kwargs)
 
-    def map_power(self, expr, *args, **kwargs):
+    def map_power(self, expr: p.Power, *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -795,7 +1215,8 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    def map_list(self, expr, *args, **kwargs):
+    def map_tuple(self,
+            expr: tuple[ExpressionT, ...], *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -804,9 +1225,8 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    map_tuple = map_list
-
-    def map_numpy_array(self, expr, *args, **kwargs):
+    def map_numpy_array(self,
+            expr: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -816,16 +1236,19 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    def map_multivector(self, expr, *args, **kwargs):
+    def map_multivector(self,
+            expr: MultiVector[ArithmeticExpressionT],
+            *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
         for _bits, coeff in expr.data.items():
-            self.rec(coeff)
+            self.rec(coeff, *args, **kwargs)
 
         self.post_visit(expr, *args, **kwargs)
 
-    def map_common_subexpression(self, expr, *args, **kwargs):
+    def map_common_subexpression(self,
+            expr: p.CommonSubexpression, *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -833,7 +1256,8 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    def map_left_shift(self, expr, *args, **kwargs):
+    def map_left_shift(self,
+            expr: p.LeftShift, *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -842,9 +1266,18 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    map_right_shift = map_left_shift
+    def map_right_shift(self,
+            expr: p.RightShift, *args: P.args, **kwargs: P.kwargs) -> None:
+        if not self.visit(expr, *args, **kwargs):
+            return
+
+        self.rec(expr.shift, *args, **kwargs)
+        self.rec(expr.shiftee, *args, **kwargs)
+
+        self.post_visit(expr, *args, **kwargs)
 
-    def map_bitwise_not(self, expr, *args, **kwargs):
+    def map_bitwise_not(self,
+            expr: p.BitwiseNot, *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -852,11 +1285,37 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    map_bitwise_or = map_sum
-    map_bitwise_xor = map_sum
-    map_bitwise_and = map_sum
+    def map_bitwise_or(self,
+                expr: p.BitwiseOr, *args: P.args, **kwargs: P.kwargs) -> None:
+        if not self.visit(expr, *args, **kwargs):
+            return
+
+        for child in expr.children:
+            self.rec(child, *args, **kwargs)
+
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_bitwise_xor(self,
+                expr: p.BitwiseXor, *args: P.args, **kwargs: P.kwargs) -> None:
+        if not self.visit(expr, *args, **kwargs):
+            return
+
+        for child in expr.children:
+            self.rec(child, *args, **kwargs)
+
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_bitwise_and(self,
+                expr: p.BitwiseAnd, *args: P.args, **kwargs: P.kwargs) -> None:
+        if not self.visit(expr, *args, **kwargs):
+            return
 
-    def map_comparison(self, expr, *args, **kwargs):
+        for child in expr.children:
+            self.rec(child, *args, **kwargs)
+
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_comparison(self, expr, *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -865,11 +1324,36 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    map_logical_not = map_bitwise_not
-    map_logical_and = map_sum
-    map_logical_or = map_sum
+    def map_logical_not(self,
+            expr: p.LogicalNot, *args: P.args, **kwargs: P.kwargs) -> None:
+        if not self.visit(expr, *args, **kwargs):
+            return
 
-    def map_if(self, expr, *args, **kwargs):
+        self.rec(expr.child, *args, **kwargs)
+
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_logical_or(self,
+                expr: p.LogicalOr, *args: P.args, **kwargs: P.kwargs) -> None:
+        if not self.visit(expr, *args, **kwargs):
+            return
+
+        for child in expr.children:
+            self.rec(child, *args, **kwargs)
+
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_logical_and(self,
+                expr: p.LogicalAnd, *args: P.args, **kwargs: P.kwargs) -> None:
+        if not self.visit(expr, *args, **kwargs):
+            return
+
+        for child in expr.children:
+            self.rec(child, *args, **kwargs)
+
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_if(self, expr, *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -879,7 +1363,7 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    def map_if_positive(self, expr, *args, **kwargs):
+    def map_if_positive(self, expr, *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -889,11 +1373,28 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    map_min = map_sum
-    map_max = map_sum
+    def map_min(self,
+                expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> None:
+        if not self.visit(expr, *args, **kwargs):
+            return
 
-    def map_substitution(self, expr, *args, **kwargs):
-        if not self.visit(expr):
+        for child in expr.children:
+            self.rec(child, *args, **kwargs)
+
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_max(self,
+                expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> None:
+        if not self.visit(expr, *args, **kwargs):
+            return
+
+        for child in expr.children:
+            self.rec(child, *args, **kwargs)
+
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_substitution(self, expr, *args: P.args, **kwargs: P.kwargs) -> None:
+        if not self.visit(expr, *args, **kwargs):
             return
 
         self.rec(expr.child, *args, **kwargs)
@@ -902,7 +1403,7 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    def map_derivative(self, expr, *args, **kwargs):
+    def map_derivative(self, expr, *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -910,7 +1411,7 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    def map_slice(self, expr, *args, **kwargs):
+    def map_slice(self, expr, *args: P.args, **kwargs: P.kwargs) -> None:
         if not self.visit(expr, *args, **kwargs):
             return
 
@@ -923,10 +1424,10 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
-    def visit(self, expr, *args, **kwargs):
+    def visit(self, expr, *args: P.args, **kwargs: P.kwargs) -> bool:
         return True
 
-    def post_visit(self, expr, *args, **kwargs):
+    def post_visit(self, expr, *args: P.args, **kwargs: P.kwargs) -> None:
         pass
 
 
@@ -938,6 +1439,7 @@ class CachedWalkMapper(CachedMapper, WalkMapper):
 
 # {{{ callback mapper
 
+# FIXME: Is it worth typing this?
 class CallbackMapper(Mapper):
     def __init__(self, function, fallback_mapper):
         self.function = function
@@ -984,7 +1486,7 @@ class CallbackMapper(Mapper):
 
 # {{{ caching mixins
 
-class CSECachingMapperMixin(ABC):
+class CSECachingMapperMixin(ABC, Generic[ResultT, P]):
     """A :term:`mix-in` that helps
     subclassed mappers implement caching for
     :class:`pymbolic.primitives.CommonSubexpression`
@@ -999,23 +1501,28 @@ class CSECachingMapperMixin(ABC):
     This method deliberately does not support extra arguments in mapper
     dispatch, to avoid spurious dependencies of the cache on these arguments.
     """
+    _cse_cache_dict: dict[tuple[ExpressionT, P.args, P.kwargs], ResultT]
 
-    def map_common_subexpression(self, expr, *args):
+    def map_common_subexpression(self,
+                expr: p.CommonSubexpression,
+                *args: P.args, **kwargs: P.kwargs) -> ResultT:
         try:
             ccd = self._cse_cache_dict
         except AttributeError:
             ccd = self._cse_cache_dict = {}
 
-        key = (expr, *args)
+        key: tuple[ExpressionT, P.args, P.kwargs] = (expr, args, immutabledict(kwargs))
         try:
             return ccd[key]
         except KeyError:
-            result = self.map_common_subexpression_uncached(expr, *args)
+            result = self.map_common_subexpression_uncached(expr, *args, **kwargs)
             ccd[key] = result
             return result
 
     @abstractmethod
-    def map_common_subexpression_uncached(self, expr, *args):
+    def map_common_subexpression_uncached(self,
+                expr: p.CommonSubexpression,
+                *args: P.args, **kwargs: P.kwargs) -> ResultT:
         pass
 
 # }}}
diff --git a/pymbolic/mapper/coefficient.py b/pymbolic/mapper/coefficient.py
index 99516f00b8384966637e4130fa7ae35f5b313a98..54e45d1c185894725515c8ea6df889bd8dd2675a 100644
--- a/pymbolic/mapper/coefficient.py
+++ b/pymbolic/mapper/coefficient.py
@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-from pymbolic.primitives import flattened_product
-
 
 __copyright__ = "Copyright (C) 2013 Andreas Kloeckner"
 
@@ -25,17 +23,25 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from collections.abc import Collection
+from typing import Literal, Mapping, TypeAlias, cast
+
+import pymbolic.primitives as p
 from pymbolic.mapper import Mapper
+from pymbolic.typing import ArithmeticExpressionT
+
 
+CoeffsT: TypeAlias = Mapping[p.AlgebraicLeaf | Literal[1], ArithmeticExpressionT]
 
-class CoefficientCollector(Mapper):
-    def __init__(self, target_names=None):
+
+class CoefficientCollector(Mapper[CoeffsT, []]):
+    def __init__(self, target_names: Collection[str] | None = None) -> None:
         self.target_names = target_names
 
-    def map_sum(self, expr):
+    def map_sum(self, expr: p.Sum) -> CoeffsT:
         stride_dicts = [self.rec(ch) for ch in expr.children]
 
-        result = {}
+        result: dict[p.AlgebraicLeaf | Literal[1], ArithmeticExpressionT] = {}
         for stride_dict in stride_dicts:
             for var, stride in stride_dict.items():
                 if var in result:
@@ -45,9 +51,7 @@ class CoefficientCollector(Mapper):
 
         return result
 
-    def map_product(self, expr):
-        result = {}
-
+    def map_product(self, expr: p.Product) -> CoeffsT:
         children_coeffs = [self.rec(child) for child in expr.children]
 
         idx_of_child_with_vars = None
@@ -60,35 +64,33 @@ class CoefficientCollector(Mapper):
                                 "nonlinear expression")
                     idx_of_child_with_vars = i
 
-        other_coeffs = 1
+        other_coeffs: ArithmeticExpressionT = 1
         for i, child_coeffs in enumerate(children_coeffs):
             if i != idx_of_child_with_vars:
                 assert len(child_coeffs) == 1
-                other_coeffs *= child_coeffs[1]
+                other_coeffs *= cast(ArithmeticExpressionT, child_coeffs[1])
 
         if idx_of_child_with_vars is None:
             return {1: other_coeffs}
         else:
             return {
-                    var: flattened_product((other_coeffs, coeff))
+                    var: p.flattened_product((other_coeffs, coeff))
                     for var, coeff in
                     children_coeffs[idx_of_child_with_vars].items()}
 
-        return result
-
-    def map_quotient(self, expr):
+    def map_quotient(self, expr: p.Quotient) -> CoeffsT:
         from pymbolic.primitives import Quotient
-        d_num = self.rec(expr.numerator)
+        d_num = dict(self.rec(expr.numerator))
         d_den = self.rec(expr.denominator)
         # d_den should look like {1: k}
         if len(d_den) > 1 or 1 not in d_den:
             raise RuntimeError("nonlinear expression")
         val = d_den[1]
         for k in d_num.keys():
-            d_num[k] = flattened_product((d_num[k], Quotient(1, val)))
+            d_num[k] = p.flattened_product((d_num[k], Quotient(1, val)))
         return d_num
 
-    def map_power(self, expr):
+    def map_power(self, expr: p.Power) -> CoeffsT:
         d_base = self.rec(expr.base)
         d_exponent = self.rec(expr.exponent)
         # d_exponent should look like {1: k}
@@ -99,11 +101,19 @@ class CoefficientCollector(Mapper):
             raise RuntimeError("nonlinear expression")
         return {1: expr}
 
-    def map_constant(self, expr):
-        return {1: expr}
+    def map_constant(self, expr: object) -> CoeffsT:
+        assert p.is_arithmetic_expression(expr)
+        from pymbolic.primitives import is_zero
+        return {} if is_zero(expr) else {1: expr}
 
-    def map_algebraic_leaf(self, expr):
+    def map_variable(self, expr: p.Variable) -> CoeffsT:
         if self.target_names is None or expr.name in self.target_names:
             return {expr: 1}
         else:
             return {1: expr}
+
+    def map_algebraic_leaf(self, expr: p.AlgebraicLeaf) -> CoeffsT:
+        if self.target_names is None:
+            return {expr: 1}
+        else:
+            return {1: expr}
diff --git a/pymbolic/mapper/collector.py b/pymbolic/mapper/collector.py
index 2799cba3aa3348ea9ec2daeea8581a936266c7c6..e0c80db344b7ce35beefac7439d5d836dd8ba8dc 100644
--- a/pymbolic/mapper/collector.py
+++ b/pymbolic/mapper/collector.py
@@ -26,11 +26,16 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from typing import AbstractSet, Sequence, cast
+
 import pymbolic
+import pymbolic.primitives as p
 from pymbolic.mapper import IdentityMapper
+from pymbolic.mapper.dependency import DependenciesT
+from pymbolic.typing import ArithmeticExpressionT, ExpressionT
 
 
-class TermCollector(IdentityMapper):
+class TermCollector(IdentityMapper[[]]):
     """A term collector that assumes that multiplication is commutative.
 
     Allows specifying *parameters* (a set of
@@ -38,16 +43,19 @@ class TermCollector(IdentityMapper):
     coefficients and are not used for term collection.
     """
 
-    def __init__(self, parameters=None):
+    def __init__(self, parameters: AbstractSet[p.AlgebraicLeaf] | None = None):
         if parameters is None:
             parameters = set()
         self.parameters = parameters
 
-    def get_dependencies(self, expr):
+    def get_dependencies(self, expr: ExpressionT) -> DependenciesT:
         from pymbolic.mapper.dependency import DependencyMapper
         return DependencyMapper()(expr)
 
-    def split_term(self, mul_term):
+    def split_term(self, mul_term: ExpressionT) -> tuple[
+        AbstractSet[tuple[ArithmeticExpressionT, ArithmeticExpressionT]],
+        ArithmeticExpressionT
+    ]:
         """Returns  a pair consisting of:
         - a frozenset of (base, exponent) pairs
         - a product of coefficients (i.e. constants and parameters)
@@ -58,20 +66,21 @@ class TermCollector(IdentityMapper):
         """
         from pymbolic.primitives import AlgebraicLeaf, Power, Product
 
-        def base(term):
+        def base(term: ExpressionT) -> ArithmeticExpressionT:
             if isinstance(term, Power):
                 return term.base
             else:
+                assert p.is_arithmetic_expression(term)
                 return term
 
-        def exponent(term):
+        def exponent(term: ExpressionT) -> ArithmeticExpressionT:
             if isinstance(term, Power):
                 return term.exponent
             else:
                 return 1
 
         if isinstance(mul_term, Product):
-            terms = mul_term.children
+            terms: Sequence[ExpressionT] = mul_term.children
         elif isinstance(mul_term, (Power, AlgebraicLeaf)):
             terms = [mul_term]
         elif not bool(self.get_dependencies(mul_term)):
@@ -79,7 +88,7 @@ class TermCollector(IdentityMapper):
         else:
             raise RuntimeError("split_term expects a multiplicative term")
 
-        base2exp = {}
+        base2exp: dict[ArithmeticExpressionT, ArithmeticExpressionT] = {}
         for term in terms:
             mybase = base(term)
             myexp = exponent(term)
@@ -91,20 +100,23 @@ class TermCollector(IdentityMapper):
 
         coefficients = []
         cleaned_base2exp = {}
-        for base, exp in base2exp.items():
-            term = base**exp
+        for item_base, item_exp in base2exp.items():
+            term = item_base**item_exp
             if self.get_dependencies(term) <= self.parameters:
                 coefficients.append(term)
             else:
-                cleaned_base2exp[base] = exp
+                cleaned_base2exp[item_base] = item_exp
 
-        term = frozenset(
+        base_exp_set = frozenset(
                 (base, exp) for base, exp in cleaned_base2exp.items())
-        return term, self.rec(pymbolic.flattened_product(coefficients))
-
-    def map_sum(self, mysum):
-        term2coeff = {}
-        for child in mysum.children:
+        return base_exp_set, cast(ArithmeticExpressionT,
+                self.rec(pymbolic.flattened_product(coefficients)))
+
+    def map_sum(self, expr: p.Sum) -> ExpressionT:
+        term2coeff: dict[
+            AbstractSet[tuple[ArithmeticExpressionT, ArithmeticExpressionT]],
+            ArithmeticExpressionT] = {}
+        for child in expr.children:
             term, coeff = self.split_term(child)
             term2coeff[term] = term2coeff.get(term, 0) + coeff
 
diff --git a/pymbolic/mapper/constant_folder.py b/pymbolic/mapper/constant_folder.py
index 62b3eddf2179f3c77513c108c2776fd0f6b75ee6..9bf31b854c9fedeff9cb19b33bd303e91ff6238e 100644
--- a/pymbolic/mapper/constant_folder.py
+++ b/pymbolic/mapper/constant_folder.py
@@ -27,13 +27,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from collections.abc import Callable
+
 from pymbolic.mapper import (
     CSECachingMapperMixin,
     IdentityMapper,
+    Mapper,
 )
+from pymbolic.primitives import Product, Sum, is_arithmetic_expression
+from pymbolic.typing import ArithmeticExpressionT, ExpressionT
 
 
-class ConstantFoldingMapperBase:
+class ConstantFoldingMapperBase(Mapper[ExpressionT, []]):
     def is_constant(self, expr):
         from pymbolic.mapper.dependency import DependencyMapper
         return not bool(DependencyMapper()(expr))
@@ -45,15 +50,27 @@ class ConstantFoldingMapperBase:
         except ValueError:
             return None
 
-    def fold(self, expr, klass, op, constructor):
+    def fold(self,
+             expr: Sum | Product,
+             op: Callable[
+                 [ArithmeticExpressionT, ArithmeticExpressionT],
+                 ArithmeticExpressionT],
+             constructor: Callable[
+                     [tuple[ArithmeticExpressionT, ...]],
+                     ArithmeticExpressionT],
+         ) -> ExpressionT:
+        klass = type(expr)
 
-        constants = []
-        nonconstants = []
+        constants: list[ArithmeticExpressionT] = []
+        nonconstants: list[ArithmeticExpressionT] = []
 
         queue = list(expr.children)
         while queue:
-            child = self.rec(queue.pop(0))  # pylint:disable=no-member
+            child = self.rec(queue.pop(0))
+            assert is_arithmetic_expression(child)
+
             if isinstance(child, klass):
+                assert isinstance(child, (Sum, Product))
                 queue = list(child.children) + queue
             else:
                 if self.is_constant(child):
@@ -73,37 +90,36 @@ class ConstantFoldingMapperBase:
         else:
             return constructor(tuple(nonconstants))
 
-    def map_sum(self, expr):
+    def map_sum(self, expr: Sum) -> ExpressionT:
         import operator
 
-        from pymbolic.primitives import Sum, flattened_sum
+        from pymbolic.primitives import flattened_sum
 
-        return self.fold(expr, Sum, operator.add, flattened_sum)
+        return self.fold(expr, operator.add, flattened_sum)
 
 
 class CommutativeConstantFoldingMapperBase(ConstantFoldingMapperBase):
     def map_product(self, expr):
         import operator
 
-        from pymbolic.primitives import Product, flattened_product
+        from pymbolic.primitives import flattened_product
 
-        return self.fold(expr, Product, operator.mul, flattened_product)
+        return self.fold(expr, operator.mul, flattened_product)
 
 
 class ConstantFoldingMapper(
-        CSECachingMapperMixin,
+        CSECachingMapperMixin[ExpressionT, []],
         ConstantFoldingMapperBase,
-        IdentityMapper):
+        IdentityMapper[[]]):
 
     map_common_subexpression_uncached = \
             IdentityMapper.map_common_subexpression
 
 
-# Yes, map_product incompatible: missing *args, **kwargs
-class CommutativeConstantFoldingMapper(    # type: ignore[misc]
-        CSECachingMapperMixin,
+class CommutativeConstantFoldingMapper(
+        CSECachingMapperMixin[ExpressionT, []],
         CommutativeConstantFoldingMapperBase,
-        IdentityMapper):
+        IdentityMapper[[]]):
 
     map_common_subexpression_uncached = \
             IdentityMapper.map_common_subexpression
diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py
index b4473e3a72d9cfdf4af0821492a02c62028fb1f4..a4e5b0f629e76a8b2a2fdfba20b417a871ed94bd 100644
--- a/pymbolic/mapper/dependency.py
+++ b/pymbolic/mapper/dependency.py
@@ -2,6 +2,7 @@
 .. autoclass:: DependencyMapper
 .. autoclass:: CachedDependencyMapper
 """
+
 from __future__ import annotations
 
 
@@ -27,10 +28,21 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from pymbolic.mapper import CachedMapper, Collector, CSECachingMapperMixin
+from typing import AbstractSet
+
+from typing_extensions import TypeAlias
+
+import pymbolic.primitives as p
+from pymbolic.mapper import CachedMapper, Collector, CSECachingMapperMixin, P
+
 
+DependenciesT: TypeAlias = AbstractSet[p.AlgebraicLeaf | p.CommonSubexpression]
 
-class DependencyMapper(CSECachingMapperMixin, Collector):
+
+class DependencyMapper(
+    CSECachingMapperMixin[DependenciesT, P],
+    Collector[p.AlgebraicLeaf | p.CommonSubexpression, P],
+):
     """Maps an expression to the :class:`set` of expressions it
     is based on. The ``include_*`` arguments to the constructor
     determine which types of objects occur in this output set.
@@ -38,12 +50,14 @@ class DependencyMapper(CSECachingMapperMixin, Collector):
     instances are included.
     """
 
-    def __init__(self,
-            include_subscripts=True,
-            include_lookups=True,
-            include_calls=True,
-            include_cses=False,
-            composite_leaves=None):
+    def __init__(
+        self,
+        include_subscripts: bool = True,
+        include_lookups: bool = True,
+        include_calls: bool = True,
+        include_cses: bool = False,
+        composite_leaves: bool | None = None,
+    ):
         """
         :arg composite_leaves: Setting this is equivalent to setting
             all preceding ``include_*`` flags.
@@ -66,68 +80,92 @@ class DependencyMapper(CSECachingMapperMixin, Collector):
 
         self.include_cses = include_cses
 
-    def map_variable(self, expr, *args, **kwargs):
+    def map_variable(
+        self, expr: p.Variable, *args: P.args, **kwargs: P.kwargs
+    ) -> DependenciesT:
         return {expr}
 
-    def map_call(self, expr, *args, **kwargs):
+    def map_call(
+        self, expr: p.Call, *args: P.args, **kwargs: P.kwargs
+    ) -> DependenciesT:
         if self.include_calls == "descend_args":
-            return self.combine(
-                    [self.rec(child, *args, **kwargs) for child in expr.parameters])
+            return self.combine([
+                self.rec(child, *args, **kwargs) for child in expr.parameters
+            ])
         elif self.include_calls:
             return {expr}
         else:
             return super().map_call(expr, *args, **kwargs)
 
-    def map_call_with_kwargs(self, expr, *args, **kwargs):
+    def map_call_with_kwargs(
+        self, expr: p.CallWithKwargs, *args: P.args, **kwargs: P.kwargs
+    ) -> DependenciesT:
         if self.include_calls == "descend_args":
             return self.combine(
-                    [self.rec(child, *args, **kwargs) for child in expr.parameters]
-                    + [self.rec(val, *args, **kwargs) for name, val in
-                    expr.kw_parameters.items()]
-                    )
+                [self.rec(child, *args, **kwargs) for child in expr.parameters]
+                + [
+                    self.rec(val, *args, **kwargs)
+                    for name, val in expr.kw_parameters.items()
+                ]
+            )
         elif self.include_calls:
             return {expr}
         else:
             return super().map_call_with_kwargs(expr, *args, **kwargs)
 
-    def map_lookup(self, expr, *args, **kwargs):
+    def map_lookup(
+        self, expr: p.Lookup, *args: P.args, **kwargs: P.kwargs
+    ) -> DependenciesT:
         if self.include_lookups:
             return {expr}
         else:
             return super().map_lookup(expr, *args, **kwargs)
 
-    def map_subscript(self, expr, *args, **kwargs):
+    def map_subscript(
+        self, expr: p.Subscript, *args: P.args, **kwargs: P.kwargs
+    ) -> DependenciesT:
         if self.include_subscripts:
             return {expr}
         else:
             return super().map_subscript(expr, *args, **kwargs)
 
-    def map_common_subexpression_uncached(self, expr, *args, **kwargs):
+    def map_common_subexpression_uncached(
+        self, expr: p.CommonSubexpression, *args: P.args, **kwargs: P.kwargs
+    ) -> DependenciesT:
         if self.include_cses:
             return {expr}
         else:
-            return Collector.map_common_subexpression(self, expr, *args, **kwargs)
-
-    def map_slice(self, expr, *args, **kwargs):
-        return self.combine(
-                [self.rec(child, *args, **kwargs) for child in expr.children
-                    if child is not None])
-
-    def map_nan(self, expr, *args, **kwargs):
+            # FIXME: These look like mypy bugs, revisit
+            return Collector.map_common_subexpression(self, expr, *args, **kwargs)  # type: ignore[return-value, arg-type]
+
+    def map_slice(
+        self, expr: p.Slice, *args: P.args, **kwargs: P.kwargs
+    ) -> DependenciesT:
+        return self.combine([
+            self.rec(child, *args, **kwargs)
+            for child in expr.children
+            if child is not None
+        ])
+
+    def map_nan(self, expr: p.NaN, *args: P.args, **kwargs: P.kwargs) -> DependenciesT:
         return set()
 
 
 class CachedDependencyMapper(CachedMapper, DependencyMapper):
-    def __init__(self,
-                 include_subscripts=True,
-                 include_lookups=True,
-                 include_calls=True,
-                 include_cses=False,
-                 composite_leaves=None):
+    def __init__(
+        self,
+        include_subscripts=True,
+        include_lookups=True,
+        include_calls=True,
+        include_cses=False,
+        composite_leaves=None,
+    ):
         CachedMapper.__init__(self)
-        DependencyMapper.__init__(self,
-                                  include_subscripts=include_subscripts,
-                                  include_lookups=include_lookups,
-                                  include_calls=include_calls,
-                                  include_cses=include_cses,
-                                  composite_leaves=composite_leaves)
+        DependencyMapper.__init__(
+            self,
+            include_subscripts=include_subscripts,
+            include_lookups=include_lookups,
+            include_calls=include_calls,
+            include_cses=include_cses,
+            composite_leaves=composite_leaves,
+        )
diff --git a/pymbolic/mapper/distributor.py b/pymbolic/mapper/distributor.py
index faa523e651e6a363bfa53b5e6326a79cee50f671..85d13504f1af639175cefb10a02d515c7628c7df 100644
--- a/pymbolic/mapper/distributor.py
+++ b/pymbolic/mapper/distributor.py
@@ -27,14 +27,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from typing import cast
+
 import pymbolic
+import pymbolic.primitives as p
 from pymbolic.mapper import IdentityMapper
 from pymbolic.mapper.collector import TermCollector
 from pymbolic.mapper.constant_folder import CommutativeConstantFoldingMapper
-from pymbolic.primitives import Product, Sum, is_zero
+from pymbolic.typing import ArithmeticExpressionT, ExpressionT
 
 
-class DistributeMapper(IdentityMapper):
+class DistributeMapper(IdentityMapper[[]]):
     """Example usage:
 
     .. doctest::
@@ -47,7 +50,7 @@ class DistributeMapper(IdentityMapper):
         7*x**6 + 21*x**5 + 21*x**2 + 35*x**3 + 1 + 35*x**4 + 7*x + x**7
     """
 
-    def __init__(self, collector=None, const_folder=None):
+    def __init__(self, collector=None, const_folder=None) -> None:
         if collector is None:
             collector = TermCollector()
         if const_folder is None:
@@ -61,19 +64,19 @@ class DistributeMapper(IdentityMapper):
 
     def map_sum(self, expr):
         res = IdentityMapper.map_sum(self, expr)
-        if isinstance(res, Sum):
+        if isinstance(res, p.Sum):
             return self.collect(res)
         else:
             return res
 
-    def map_product(self, expr):
+    def map_product(self, expr: p.Product) -> ExpressionT:
         def dist(prod):
-            if not isinstance(prod, Product):
+            if not isinstance(prod, p.Product):
                 return prod
 
             leading = []
             for i in prod.children:
-                if isinstance(i, Sum):
+                if isinstance(i, p.Sum):
                     break
                 else:
                     leading.append(i)
@@ -84,10 +87,10 @@ class DistributeMapper(IdentityMapper):
                 return result
             else:
                 sum = prod.children[len(leading)]
-                assert isinstance(sum, Sum)
+                assert isinstance(sum, p.Sum)
                 rest = prod.children[len(leading)+1:]
                 if rest:
-                    rest = dist(Product(rest))
+                    rest = dist(p.Product(rest))
                 else:
                     rest = 1
 
@@ -100,7 +103,7 @@ class DistributeMapper(IdentityMapper):
         return dist(IdentityMapper.map_product(self, expr))
 
     def map_quotient(self, expr):
-        if is_zero(expr.numerator - 1):
+        if p.is_zero(expr.numerator - 1):
             return expr
         else:
             # not the smartest thing we can do, but at least *something*
@@ -109,18 +112,19 @@ class DistributeMapper(IdentityMapper):
                     self.rec(expr.numerator)
                     ])
 
-    def map_power(self, expr):
+    def map_power(self, expr: p.Power) -> ExpressionT:
         from pymbolic.primitives import Sum
 
         newbase = self.rec(expr.base)
-        if isinstance(expr.base, Product):
+        if isinstance(newbase, p.Product):
             return self.rec(pymbolic.flattened_product([
-                child**expr.exponent for child in newbase
+                cast(ArithmeticExpressionT, child)**expr.exponent
+                    for child in newbase.children
                 ]))
 
         if isinstance(expr.exponent, int):
             if isinstance(newbase, Sum):
-                return self.map_product(
+                return self.rec(
                         pymbolic.flattened_product(
                             expr.exponent*(newbase,)))
             else:
@@ -129,7 +133,7 @@ class DistributeMapper(IdentityMapper):
             return IdentityMapper.map_power(self, expr)
 
 
-def distribute(expr, parameters=None, commutative=True):
+def distribute(expr: ExpressionT, parameters=None, commutative=True) -> ExpressionT:
     if parameters is None:
         parameters = frozenset()
     if commutative:
diff --git a/pymbolic/mapper/evaluator.py b/pymbolic/mapper/evaluator.py
index 44e9966ba9484f7957caac0730a8089b9299aa2d..67b7334f117c9d5c643ef53c0786cfa7ff944565 100644
--- a/pymbolic/mapper/evaluator.py
+++ b/pymbolic/mapper/evaluator.py
@@ -33,19 +33,27 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-
 import operator as op
+from collections.abc import Mapping
 from functools import reduce
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
+import pymbolic.primitives as p
 from pymbolic.mapper import CachedMapper, CSECachingMapperMixin, Mapper
+from pymbolic.typing import ExpressionT
+
+
+if TYPE_CHECKING:
+    import numpy as np
+
+    from pymbolic.geometric_algebra import MultiVector
 
 
 class UnknownVariableError(Exception):
     pass
 
 
-class EvaluationMapper(Mapper, CSECachingMapperMixin):
+class EvaluationMapper(Mapper[Any, []], CSECachingMapperMixin):
     """Example usage:
 
     .. doctest::
@@ -62,7 +70,9 @@ class EvaluationMapper(Mapper, CSECachingMapperMixin):
         110
     """
 
-    def __init__(self, context=None):
+    context: Mapping[str, Any]
+
+    def __init__(self, context: Mapping[str, Any] | None = None) -> None:
         """
         :arg context: a mapping from variable names to values
         """
@@ -70,21 +80,20 @@ class EvaluationMapper(Mapper, CSECachingMapperMixin):
             context = {}
 
         self.context = context
-        self.common_subexp_cache = {}
 
-    def map_constant(self, expr):
+    def map_constant(self, expr: object) -> Any:
         return expr
 
-    def map_variable(self, expr):
+    def map_variable(self, expr: p.Variable) -> None:
         try:
             return self.context[expr.name]
         except KeyError:
             raise UnknownVariableError(expr.name) from None
 
-    def map_call(self, expr):
+    def map_call(self, expr: p.Call) -> Any:
         return self.rec(expr.function)(*[self.rec(par) for par in expr.parameters])
 
-    def map_call_with_kwargs(self, expr):
+    def map_call_with_kwargs(self, expr: p.CallWithKwargs) -> Any:
         args = [self.rec(par) for par in expr.parameters]
         kwargs = {
                 k: self.rec(v)
@@ -92,109 +101,97 @@ class EvaluationMapper(Mapper, CSECachingMapperMixin):
 
         return self.rec(expr.function)(*args, **kwargs)
 
-    def map_subscript(self, expr):
-        rec_result = self.rec(expr.aggregate)
+    def map_subscript(self, expr: p.Subscript) -> Any:
+        return self.rec(expr.aggregate)[self.rec(expr.index)]
 
-        from pymbolic.primitives import Expression
-        if isinstance(rec_result, Expression):
-            return rec_result.index(self.rec(expr.index))
-        else:
-            return rec_result[self.rec(expr.index)]
-
-    def map_lookup(self, expr):
+    def map_lookup(self, expr: p.Lookup) -> Any:
         return getattr(self.rec(expr.aggregate), expr.name)
 
-    def map_sum(self, expr):
+    def map_sum(self, expr: p.Sum) -> Any:
         return sum(self.rec(child) for child in expr.children)
 
-    def map_product(self, expr):
+    def map_product(self, expr: p.Product) -> Any:
         from pytools import product
         return product(self.rec(child) for child in expr.children)
 
-    def map_quotient(self, expr):
+    def map_quotient(self, expr: p.Quotient) -> Any:
         return self.rec(expr.numerator) / self.rec(expr.denominator)
 
-    def map_floor_div(self, expr):
+    def map_floor_div(self, expr: p.FloorDiv) -> Any:
         return self.rec(expr.numerator) // self.rec(expr.denominator)
 
-    def map_remainder(self, expr):
+    def map_remainder(self, expr: p.Remainder) -> Any:
         return self.rec(expr.numerator) % self.rec(expr.denominator)
 
-    def map_power(self, expr):
+    def map_power(self, expr: p.Power) -> Any:
         return self.rec(expr.base) ** self.rec(expr.exponent)
 
-    def map_left_shift(self, expr):
+    def map_left_shift(self, expr: p.LeftShift) -> Any:
         return self.rec(expr.shiftee) << self.rec(expr.shift)
 
-    def map_right_shift(self, expr):
+    def map_right_shift(self, expr: p.RightShift) -> Any:
         return self.rec(expr.shiftee) >> self.rec(expr.shift)
 
-    def map_bitwise_not(self, expr):
+    def map_bitwise_not(self, expr: p.BitwiseNot) -> Any:
         # ??? Why, pylint, why ???
         # pylint: disable=invalid-unary-operand-type
         return ~self.rec(expr.child)
 
-    def map_bitwise_or(self, expr):
+    def map_bitwise_or(self, expr: p.BitwiseOr) -> Any:
         return reduce(op.or_, (self.rec(ch) for ch in expr.children))
 
-    def map_bitwise_xor(self, expr):
+    def map_bitwise_xor(self, expr: p.BitwiseXor) -> Any:
         return reduce(op.xor, (self.rec(ch) for ch in expr.children))
 
-    def map_bitwise_and(self, expr):
+    def map_bitwise_and(self, expr: p.BitwiseAnd) -> Any:
         return reduce(op.and_, (self.rec(ch) for ch in expr.children))
 
-    def map_logical_not(self, expr):
+    def map_logical_not(self, expr: p.LogicalNot) -> Any:
         return not self.rec(expr.child)
 
-    def map_logical_or(self, expr):
+    def map_logical_or(self, expr: p.LogicalOr) -> Any:
         return any(self.rec(ch) for ch in expr.children)
 
-    def map_logical_and(self, expr):
+    def map_logical_and(self, expr: p.LogicalAnd) -> Any:
         return all(self.rec(ch) for ch in expr.children)
 
-    def map_list(self, expr):
+    def map_list(self, expr: list[ExpressionT]) -> Any:
         return [self.rec(child) for child in expr]
 
-    def map_numpy_array(self, expr):
+    def map_numpy_array(self, expr: np.ndarray) -> Any:
         import numpy
         result = numpy.empty(expr.shape, dtype=object)
         for i in numpy.ndindex(expr.shape):
             result[i] = self.rec(expr[i])
         return result
 
-    def map_multivector(self, expr, *args):
-        return expr.map(lambda ch: self.rec(ch, *args))
+    def map_multivector(self, expr: MultiVector) -> Any:
+        return expr.map(lambda ch: self.rec(ch))
 
-    def map_common_subexpression_uncached(self, expr):
+    def map_common_subexpression_uncached(self, expr: p.CommonSubexpression) -> Any:
         return self.rec(expr.child)
 
-    def map_if_positive(self, expr):
-        if self.rec(expr.criterion) > 0:
+    def map_if(self, expr: p.If) -> Any:
+        if self.rec(expr.condition):
             return self.rec(expr.then)
         else:
             return self.rec(expr.else_)
 
-    def map_comparison(self, expr):
+    def map_comparison(self, expr: p.Comparison) -> Any:
         import operator
         return getattr(operator, expr.operator_to_name[expr.operator])(
             self.rec(expr.left), self.rec(expr.right))
 
-    def map_if(self, expr):
-        if self.rec(expr.condition):
-            return self.rec(expr.then)
-        else:
-            return self.rec(expr.else_)
-
-    def map_min(self, expr):
+    def map_min(self, expr: p.Min) -> Any:
         return min(self.rec(child) for child in expr.children)
 
-    def map_max(self, expr):
+    def map_max(self, expr: p.Max) -> Any:
         return max(self.rec(child) for child in expr.children)
 
-    def map_tuple(self, expr):
+    def map_tuple(self, expr: tuple[ExpressionT, ...]) -> Any:
         return tuple([self.rec(child) for child in expr])
 
-    def map_nan(self, expr):
+    def map_nan(self, expr: p.NaN) -> Any:
         if expr.data_type is None:
             from math import nan
             return nan
diff --git a/pymbolic/mapper/flattener.py b/pymbolic/mapper/flattener.py
index 4041206e6478b6b11582abe683ff68a7d88865c6..121cf65797657c6403b5c04d952974878484cc65 100644
--- a/pymbolic/mapper/flattener.py
+++ b/pymbolic/mapper/flattener.py
@@ -31,10 +31,11 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from typing import cast
 
 import pymbolic.primitives as p
 from pymbolic.mapper import IdentityMapper
-from pymbolic.typing import ExpressionT
+from pymbolic.typing import ArithmeticExpressionT, ArithmeticOrExpressionT, ExpressionT
 
 
 class FlattenMapper(IdentityMapper[[]]):
@@ -50,16 +51,19 @@ class FlattenMapper(IdentityMapper[[]]):
     """
     def map_sum(self, expr: p.Sum) -> ExpressionT:
         from pymbolic.primitives import flattened_sum
-        return flattened_sum([self.rec(ch) for ch in expr.children])
+        return flattened_sum([
+                             cast(ArithmeticExpressionT, self.rec(ch))
+                             for ch in expr.children])
 
     def map_product(self, expr: p.Product) -> ExpressionT:
         from pymbolic.primitives import flattened_product
-        return flattened_product([self.rec(ch) for ch in expr.children])
+        return flattened_product([
+                                 cast(ArithmeticExpressionT, self.rec(ch))
+                                 for ch in expr.children])
 
     def map_quotient(self, expr: p.Quotient) -> ExpressionT:
-        r_num = self.rec(expr.numerator)
-        r_den = self.rec(expr.denominator)
-        assert p.is_arithmetic_expression(r_den)
+        r_num = self.rec_arith(expr.numerator)
+        r_den = self.rec_arith(expr.denominator)
         if p.is_zero(r_num):
             return 0
         if p.is_zero(r_den - 1):
@@ -68,9 +72,8 @@ class FlattenMapper(IdentityMapper[[]]):
         return expr.__class__(r_num, r_den)
 
     def map_floor_div(self, expr: p.FloorDiv) -> ExpressionT:
-        r_num = self.rec(expr.numerator)
-        r_den = self.rec(expr.denominator)
-        assert p.is_arithmetic_expression(r_den)
+        r_num = self.rec_arith(expr.numerator)
+        r_den = self.rec_arith(expr.denominator)
         if p.is_zero(r_num):
             return 0
         if p.is_zero(r_den - 1):
@@ -79,8 +82,8 @@ class FlattenMapper(IdentityMapper[[]]):
         return expr.__class__(r_num, r_den)
 
     def map_remainder(self, expr: p.Remainder) -> ExpressionT:
-        r_num = self.rec(expr.numerator)
-        r_den = self.rec(expr.denominator)
+        r_num = self.rec_arith(expr.numerator)
+        r_den = self.rec_arith(expr.denominator)
         assert p.is_arithmetic_expression(r_den)
         if p.is_zero(r_num):
             return 0
@@ -90,10 +93,8 @@ class FlattenMapper(IdentityMapper[[]]):
         return expr.__class__(r_num, r_den)
 
     def map_power(self, expr: p.Power) -> ExpressionT:
-        r_base = self.rec(expr.base)
-        r_exp = self.rec(expr.exponent)
-
-        assert p.is_arithmetic_expression(r_exp)
+        r_base = self.rec_arith(expr.base)
+        r_exp = self.rec_arith(expr.exponent)
 
         if p.is_zero(r_exp - 1):
             return r_base
@@ -101,5 +102,5 @@ class FlattenMapper(IdentityMapper[[]]):
         return expr.__class__(r_base, r_exp)
 
 
-def flatten(expr):
-    return FlattenMapper()(expr)
+def flatten(expr: ArithmeticOrExpressionT) -> ArithmeticOrExpressionT:
+    return cast(ArithmeticOrExpressionT, FlattenMapper()(expr))
diff --git a/pymbolic/mapper/optimize.py b/pymbolic/mapper/optimize.py
index b7c848b1f3af5b2a9de70668407c9fe5b17da482..14a2987e958fcaa5ee3d6d01f761775e38b3500f 100644
--- a/pymbolic/mapper/optimize.py
+++ b/pymbolic/mapper/optimize.py
@@ -129,15 +129,15 @@ class _RecInliner(ast.NodeTransformer):
         self.inline_rec = inline_rec
         self.inline_cache = inline_cache
 
-    def visit_Call(self, node):  # noqa: N802
-        node = self.generic_visit(node)
+    def visit_Call(self, node: ast.Call) -> ast.AST:  # noqa: N802
+        node = cast(ast.Call, self.generic_visit(node))
 
-        result_expr = node
+        result_expr: ast.expr = node
 
         if (isinstance(node.func, ast.Attribute)
                 and isinstance(node.func.value, ast.Name)
                 and node.func.value.id == "self"
-                and node.func.attr == "rec"):
+                and node.func.attr in ["rec", "rec_arith"]):
 
             from ast import (
                 Attribute,
@@ -191,7 +191,7 @@ class _RecInliner(ast.NodeTransformer):
                         args=[expr],
                         keywords=[])
                 cache_key_expr = ast.Tuple([expr_type, expr], ctx=Load())
-                nic = Name(id="_NOT_IN_CACHE", ctx=Load())
+                nic = Name(id="_NotInCache", ctx=Load())
 
                 result_expr = IfExp(
                         test=Compare(
diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py
index 4fc08e1b07c6135c9ec54512e05cfc0713ad80cc..46ba83f5fcdcb6ee79e648dc065b6396f093222f 100644
--- a/pymbolic/mapper/stringifier.py
+++ b/pymbolic/mapper/stringifier.py
@@ -22,11 +22,21 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, ClassVar, Concatenate
+from warnings import warn
 
-from typing import ClassVar
+from typing_extensions import deprecated
 
 import pymbolic.primitives as p
-from pymbolic.mapper import CachedMapper, Mapper
+from pymbolic.mapper import CachedMapper, Mapper, P
+from pymbolic.typing import ExpressionT
+
+
+if TYPE_CHECKING:
+    import numpy as np
+
+    from pymbolic.geometric_algebra import MultiVector
 
 
 __doc__ = """
@@ -81,7 +91,8 @@ PREC_NONE = 0
 
 # {{{ stringifier
 
-class StringifyMapper(Mapper):
+
+class StringifyMapper(Mapper[str, Concatenate[int, P]]):
     """A mapper to turn an expression tree into a string.
 
     :class:`pymbolic.Expression.__str__` is often implemented using
@@ -94,13 +105,25 @@ class StringifyMapper(Mapper):
 
     # {{{ replaceable string composition interface
 
-    def format(self, s, *args):
+    def format(self, s: str, *args: object) -> str:
         return s % args
 
-    def join(self, joiner, iterable):
-        return self.format(joiner.join("%s" for _ in iterable), *iterable)
+    def join(self, joiner: str, seq: Sequence[ExpressionT]) -> str:
+        return self.format(joiner.join("%s" for _ in seq), *seq)
+
+    # {{{ deprecated junk
 
+    @deprecated("interface not type-safe, use rec_with_parens_around_types")
     def rec_with_force_parens_around(self, expr, *args, **kwargs):
+        warn(
+            "rec_with_force_parens_around is deprecated and will be removed in 2025. "
+            "Use rec_with_parens_around_types instead. ",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        # Not currently possible to make this type-safe:
+        # https://peps.python.org/pep-0612/#concatenating-keyword-parameters
+
         force_parens_around = kwargs.pop("force_parens_around", ())
 
         result = self.rec(expr, *args, **kwargs)
@@ -110,16 +133,77 @@ class StringifyMapper(Mapper):
 
         return result
 
-    def join_rec(self, joiner, iterable, prec, *args, **kwargs):
-        f = joiner.join("%s" for _ in iterable)
-        return self.format(f,
-                *[self.rec_with_force_parens_around(i, prec, *args, **kwargs)
-                    for i in iterable])
+    def join_rec(
+        self,
+        joiner: str,
+        seq: Sequence[ExpressionT],
+        prec: int,
+        *args,
+        **kwargs,  # force_with_parens_around may hide in here
+    ) -> str:
+        f = joiner.join("%s" for _ in seq)
+
+        if "force_parens_around" in kwargs:
+            warn(
+                "Passing force_parens_around join_rec is deprecated and will be "
+                "removed in 2025. "
+                "Use join_rec_with_parens_around_types instead. ",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            # Not currently possible to make this type-safe:
+            # https://peps.python.org/pep-0612/#concatenating-keyword-parameters
+            parens_around_types: tuple[type, ...] = kwargs.pop("force_parens_around")
+            return self.join_rec_with_parens_around_types(
+                joiner, seq, prec, parens_around_types, *args, **kwargs
+            )
+
+        return self.format(
+            f,
+            *[self.rec(i, prec, *args, **kwargs) for i in seq],
+        )
+
+    # }}}
+
+    def rec_with_parens_around_types(
+        self,
+        expr: ExpressionT,
+        enclosing_prec: int,
+        parens_around: tuple[type, ...],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
+        result = self.rec(expr, enclosing_prec, *args, **kwargs)
+
+        if isinstance(expr, parens_around):
+            result = f"({result})"
 
-    def parenthesize(self, s):
+        return result
+
+    def join_rec_with_parens_around_types(
+        self,
+        joiner: str,
+        seq: Sequence[ExpressionT],
+        prec: int,
+        parens_around_types: tuple[type, ...],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
+        f = joiner.join("%s" for _ in seq)
+        return self.format(
+            f,
+            *[
+                self.rec_with_parens_around_types(
+                    i, prec, parens_around_types, *args, **kwargs
+                )
+                for i in seq
+            ],
+        )
+
+    def parenthesize(self, s: str) -> str:
         return f"({s})"
 
-    def parenthesize_if_needed(self, s, enclosing_prec, my_prec):
+    def parenthesize_if_needed(self, s: str, enclosing_prec: int, my_prec: int) -> str:
         if enclosing_prec > my_prec:
             return f"({s})"
         else:
@@ -129,210 +213,391 @@ class StringifyMapper(Mapper):
 
     # {{{ mappings
 
-    def handle_unsupported_expression(self, expr, enclosing_prec, *args, **kwargs):
+    def handle_unsupported_expression(
+        self, expr, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         strifier = expr.make_stringifier(self)
         if isinstance(self, type(strifier)):
-            raise ValueError(
-                    f"stringifier '{self}' can't handle '{expr.__class__}'")
-        return strifier(
-                expr, enclosing_prec, *args, **kwargs)
+            raise ValueError(f"stringifier '{self}' can't handle '{expr.__class__}'")
+        return strifier(expr, enclosing_prec, *args, **kwargs)
 
-    def map_constant(self, expr, enclosing_prec, *args, **kwargs):
+    def map_constant(
+        self, expr: object, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         result = str(expr)
 
-        if not (result.startswith("(") and result.endswith(")")) \
-                and ("-" in result or "+" in result) \
-                and (enclosing_prec > PREC_SUM):
+        if (
+            not (result.startswith("(") and result.endswith(")"))
+            and ("-" in result or "+" in result)
+            and (enclosing_prec > PREC_SUM)
+        ):
             return self.parenthesize(result)
         else:
             return result
 
-    def map_variable(self, expr, enclosing_prec, *args, **kwargs):
+    def map_variable(
+        self, expr: p.Variable, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return expr.name
 
-    def map_wildcard(self, expr, enclosing_prec, *args, **kwargs):
+    def map_wildcard(
+        self, expr: p.Wildcard, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return "*"
 
-    def map_function_symbol(self, expr, enclosing_prec, *args, **kwargs):
+    def map_function_symbol(
+        self,
+        expr: p.FunctionSymbol,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         return expr.__class__.__name__
 
-    def map_call(self, expr, enclosing_prec, *args, **kwargs):
-        return self.format("%s(%s)",
-                self.rec(expr.function, PREC_CALL, *args, **kwargs),
-                self.join_rec(", ", expr.parameters, PREC_NONE, *args, **kwargs))
-
-    def map_call_with_kwargs(self, expr, enclosing_prec, *args, **kwargs):
-        args_strings = (
-                tuple([
-                    self.rec(ch, PREC_NONE, *args, **kwargs)
-                    for ch in expr.parameters
-                    ])
-                +
-                tuple([
-                    "{}={}".format(name, self.rec(ch, PREC_NONE, *args, **kwargs))
-                    for name, ch in expr.kw_parameters.items()
-                    ])
-                )
-        return self.format("%s(%s)",
-                self.rec(expr.function, PREC_CALL, *args, **kwargs),
-                ", ".join(args_strings))
-
-    def map_subscript(self, expr, enclosing_prec, *args, **kwargs):
+    def map_call(
+        self, expr: p.Call, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
+        return self.format(
+            "%s(%s)",
+            self.rec(expr.function, PREC_CALL, *args, **kwargs),
+            self.join_rec(", ", expr.parameters, PREC_NONE, *args, **kwargs),
+        )
+
+    def map_call_with_kwargs(
+        self,
+        expr: p.CallWithKwargs,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
+        args_strings = tuple([
+            self.rec(ch, PREC_NONE, *args, **kwargs) for ch in expr.parameters
+        ]) + tuple([
+            "{}={}".format(name, self.rec(ch, PREC_NONE, *args, **kwargs))
+            for name, ch in expr.kw_parameters.items()
+        ])
+        return self.format(
+            "%s(%s)",
+            self.rec(expr.function, PREC_CALL, *args, **kwargs),
+            ", ".join(args_strings),
+        )
+
+    def map_subscript(
+        self, expr: p.Subscript, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         if isinstance(expr.index, tuple):
             index_str = self.join_rec(", ", expr.index, PREC_NONE, *args, **kwargs)
         else:
             index_str = self.rec(expr.index, PREC_NONE, *args, **kwargs)
 
         return self.parenthesize_if_needed(
-                self.format("%s[%s]",
-                    self.rec(expr.aggregate, PREC_CALL, *args, **kwargs),
-                    index_str),
-                enclosing_prec, PREC_CALL)
-
-    def map_lookup(self, expr, enclosing_prec, *args, **kwargs):
+            self.format(
+                "%s[%s]",
+                self.rec(expr.aggregate, PREC_CALL, *args, **kwargs),
+                index_str,
+            ),
+            enclosing_prec,
+            PREC_CALL,
+        )
+
+    def map_lookup(
+        self, expr: p.Lookup, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.format("%s.%s",
-                    self.rec(expr.aggregate, PREC_CALL, *args, **kwargs),
-                    expr.name),
-                enclosing_prec, PREC_CALL)
-
-    def map_sum(self, expr, enclosing_prec, *args, **kwargs):
+            self.format(
+                "%s.%s", self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), expr.name
+            ),
+            enclosing_prec,
+            PREC_CALL,
+        )
+
+    def map_sum(
+        self, expr: p.Sum, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.join_rec(" + ", expr.children, PREC_SUM, *args, **kwargs),
-                enclosing_prec, PREC_SUM)
+            self.join_rec(" + ", expr.children, PREC_SUM, *args, **kwargs),
+            enclosing_prec,
+            PREC_SUM,
+        )
 
     # {{{ multiplicative operators
 
     multiplicative_primitives = (p.Product, p.Quotient, p.FloorDiv, p.Remainder)
 
-    def map_product(self, expr, enclosing_prec, *args, **kwargs):
-        kwargs["force_parens_around"] = (p.Quotient, p.FloorDiv, p.Remainder)
+    def map_product(
+        self, expr: p.Product, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.join_rec("*", expr.children, PREC_PRODUCT, *args, **kwargs),
-                enclosing_prec, PREC_PRODUCT)
-
-    def map_quotient(self, expr, enclosing_prec, *args, **kwargs):
-        kwargs["force_parens_around"] = self.multiplicative_primitives
+            self.join_rec_with_parens_around_types(
+                "*",
+                expr.children,
+                PREC_PRODUCT,
+                (p.Quotient, p.FloorDiv, p.Remainder),
+                *args,
+                **kwargs,
+            ),
+            enclosing_prec,
+            PREC_PRODUCT,
+        )
+
+    def map_quotient(
+        self, expr: p.Quotient, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.format("%s / %s",
-                    # space is necessary--otherwise '/*' becomes
-                    # start-of-comment in C. ('*' from dereference)
-                    self.rec_with_force_parens_around(expr.numerator, PREC_PRODUCT,
-                        *args, **kwargs),
-                    self.rec_with_force_parens_around(
-                        expr.denominator, PREC_PRODUCT, *args, **kwargs)),
-                enclosing_prec, PREC_PRODUCT)
-
-    def map_floor_div(self, expr, enclosing_prec, *args, **kwargs):
-        kwargs["force_parens_around"] = self.multiplicative_primitives
+            self.format(
+                "%s / %s",
+                # space is necessary--otherwise '/*' becomes
+                # start-of-comment in C. ('*' from dereference)
+                self.rec_with_parens_around_types(
+                    expr.numerator,
+                    PREC_PRODUCT,
+                    self.multiplicative_primitives,
+                    *args,
+                    **kwargs,
+                ),
+                self.rec_with_parens_around_types(
+                    expr.denominator,
+                    PREC_PRODUCT,
+                    self.multiplicative_primitives,
+                    *args,
+                    **kwargs,
+                ),
+            ),
+            enclosing_prec,
+            PREC_PRODUCT,
+        )
+
+    def map_floor_div(
+        self, expr: p.FloorDiv, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.format("%s // %s",
-                    self.rec_with_force_parens_around(
-                        expr.numerator, PREC_PRODUCT, *args, **kwargs),
-                    self.rec_with_force_parens_around(
-                        expr.denominator, PREC_PRODUCT, *args, **kwargs)),
-                enclosing_prec, PREC_PRODUCT)
-
-    def map_remainder(self, expr, enclosing_prec, *args, **kwargs):
-        kwargs["force_parens_around"] = self.multiplicative_primitives
+            self.format(
+                "%s // %s",
+                self.rec_with_parens_around_types(
+                    expr.numerator,
+                    PREC_PRODUCT,
+                    self.multiplicative_primitives,
+                    *args,
+                    **kwargs,
+                ),
+                self.rec_with_parens_around_types(
+                    expr.denominator,
+                    PREC_PRODUCT,
+                    self.multiplicative_primitives,
+                    *args,
+                    **kwargs,
+                ),
+            ),
+            enclosing_prec,
+            PREC_PRODUCT,
+        )
+
+    def map_remainder(
+        self, expr: p.Remainder, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.format("%s %% %s",
-                    self.rec_with_force_parens_around(
-                        expr.numerator, PREC_PRODUCT, *args, **kwargs),
-                    self.rec_with_force_parens_around(
-                        expr.denominator, PREC_PRODUCT, *args, **kwargs)),
-                enclosing_prec, PREC_PRODUCT)
+            self.format(
+                "%s %% %s",
+                self.rec_with_parens_around_types(
+                    expr.numerator,
+                    PREC_PRODUCT,
+                    self.multiplicative_primitives,
+                    *args,
+                    **kwargs,
+                ),
+                self.rec_with_parens_around_types(
+                    expr.denominator,
+                    PREC_PRODUCT,
+                    self.multiplicative_primitives,
+                    *args,
+                    **kwargs,
+                ),
+            ),
+            enclosing_prec,
+            PREC_PRODUCT,
+        )
 
     # }}}
 
-    def map_power(self, expr, enclosing_prec, *args, **kwargs):
+    def map_power(
+        self, expr: p.Power, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.format("%s**%s",
-                    self.rec(expr.base, PREC_POWER, *args, **kwargs),
-                    self.rec(expr.exponent, PREC_POWER, *args, **kwargs)),
-                enclosing_prec, PREC_POWER)
-
-    def map_left_shift(self, expr, enclosing_prec, *args, **kwargs):
+            self.format(
+                "%s**%s",
+                self.rec(expr.base, PREC_POWER, *args, **kwargs),
+                self.rec(expr.exponent, PREC_POWER, *args, **kwargs),
+            ),
+            enclosing_prec,
+            PREC_POWER,
+        )
+
+    def map_left_shift(
+        self, expr: p.LeftShift, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                # +1 to address
-                # https://gitlab.tiker.net/inducer/pymbolic/issues/6
-                self.format("%s << %s",
-                    self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs),
-                    self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)),
-                enclosing_prec, PREC_SHIFT)
-
-    def map_right_shift(self, expr, enclosing_prec, *args, **kwargs):
+            # +1 to address
+            # https://gitlab.tiker.net/inducer/pymbolic/issues/6
+            self.format(
+                "%s << %s",
+                self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs),
+                self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs),
+            ),
+            enclosing_prec,
+            PREC_SHIFT,
+        )
+
+    def map_right_shift(
+        self,
+        expr: p.RightShift,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         return self.parenthesize_if_needed(
-                # +1 to address
-                # https://gitlab.tiker.net/inducer/pymbolic/issues/6
-                self.format("%s >> %s",
-                    self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs),
-                    self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)),
-                enclosing_prec, PREC_SHIFT)
-
-    def map_bitwise_not(self, expr, enclosing_prec, *args, **kwargs):
+            # +1 to address
+            # https://gitlab.tiker.net/inducer/pymbolic/issues/6
+            self.format(
+                "%s >> %s",
+                self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs),
+                self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs),
+            ),
+            enclosing_prec,
+            PREC_SHIFT,
+        )
+
+    def map_bitwise_not(
+        self,
+        expr: p.BitwiseNot,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         return self.parenthesize_if_needed(
-                "~" + self.rec(expr.child, PREC_UNARY, *args, **kwargs),
-                enclosing_prec, PREC_UNARY)
-
-    def map_bitwise_or(self, expr, enclosing_prec, *args, **kwargs):
+            "~" + self.rec(expr.child, PREC_UNARY, *args, **kwargs),
+            enclosing_prec,
+            PREC_UNARY,
+        )
+
+    def map_bitwise_or(
+        self, expr: p.BitwiseOr, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.join_rec(
-                    " | ", expr.children, PREC_BITWISE_OR, *args, **kwargs),
-                enclosing_prec, PREC_BITWISE_OR)
-
-    def map_bitwise_xor(self, expr, enclosing_prec, *args, **kwargs):
+            self.join_rec(" | ", expr.children, PREC_BITWISE_OR, *args, **kwargs),
+            enclosing_prec,
+            PREC_BITWISE_OR,
+        )
+
+    def map_bitwise_xor(
+        self,
+        expr: p.BitwiseXor,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.join_rec(
-                    " ^ ", expr.children, PREC_BITWISE_XOR, *args, **kwargs),
-                enclosing_prec, PREC_BITWISE_XOR)
-
-    def map_bitwise_and(self, expr, enclosing_prec, *args, **kwargs):
+            self.join_rec(" ^ ", expr.children, PREC_BITWISE_XOR, *args, **kwargs),
+            enclosing_prec,
+            PREC_BITWISE_XOR,
+        )
+
+    def map_bitwise_and(
+        self,
+        expr: p.BitwiseAnd,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.join_rec(
-                    " & ", expr.children, PREC_BITWISE_AND, *args, **kwargs),
-                enclosing_prec, PREC_BITWISE_AND)
-
-    def map_comparison(self, expr, enclosing_prec, *args, **kwargs):
+            self.join_rec(" & ", expr.children, PREC_BITWISE_AND, *args, **kwargs),
+            enclosing_prec,
+            PREC_BITWISE_AND,
+        )
+
+    def map_comparison(
+        self, expr: p.Comparison, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.format("%s %s %s",
-                    self.rec(expr.left, PREC_COMPARISON, *args, **kwargs),
-                    expr.operator,
-                    self.rec(expr.right, PREC_COMPARISON, *args, **kwargs)),
-                enclosing_prec, PREC_COMPARISON)
-
-    def map_logical_not(self, expr, enclosing_prec, *args, **kwargs):
+            self.format(
+                "%s %s %s",
+                self.rec(expr.left, PREC_COMPARISON, *args, **kwargs),
+                expr.operator,
+                self.rec(expr.right, PREC_COMPARISON, *args, **kwargs),
+            ),
+            enclosing_prec,
+            PREC_COMPARISON,
+        )
+
+    def map_logical_not(
+        self,
+        expr: p.LogicalNot,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         return self.parenthesize_if_needed(
-                "not " + self.rec(expr.child, PREC_UNARY, *args, **kwargs),
-                enclosing_prec, PREC_UNARY)
-
-    def map_logical_or(self, expr, enclosing_prec, *args, **kwargs):
+            "not " + self.rec(expr.child, PREC_UNARY, *args, **kwargs),
+            enclosing_prec,
+            PREC_UNARY,
+        )
+
+    def map_logical_or(
+        self, expr: p.LogicalOr, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.join_rec(
-                    " or ", expr.children, PREC_LOGICAL_OR, *args, **kwargs),
-                enclosing_prec, PREC_LOGICAL_OR)
-
-    def map_logical_and(self, expr, enclosing_prec, *args, **kwargs):
+            self.join_rec(" or ", expr.children, PREC_LOGICAL_OR, *args, **kwargs),
+            enclosing_prec,
+            PREC_LOGICAL_OR,
+        )
+
+    def map_logical_and(
+        self,
+        expr: p.LogicalAnd,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.join_rec(
-                    " and ", expr.children, PREC_LOGICAL_AND, *args, **kwargs),
-                enclosing_prec, PREC_LOGICAL_AND)
-
-    def map_list(self, expr, enclosing_prec, *args, **kwargs):
+            self.join_rec(" and ", expr.children, PREC_LOGICAL_AND, *args, **kwargs),
+            enclosing_prec,
+            PREC_LOGICAL_AND,
+        )
+
+    def map_list(
+        self,
+        expr: list[ExpressionT],
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         return self.format(
-                "[%s]", self.join_rec(", ", expr, PREC_NONE, *args, **kwargs))
+            "[%s]", self.join_rec(", ", expr, PREC_NONE, *args, **kwargs)
+        )
 
     map_vector = map_list
 
-    def map_tuple(self, expr, enclosing_prec, *args, **kwargs):
+    def map_tuple(
+        self,
+        expr: tuple[ExpressionT, ...],
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         el_str = ", ".join(
-                self.rec(child, PREC_NONE, *args, **kwargs) for child in expr)
+            self.rec(child, PREC_NONE, *args, **kwargs) for child in expr
+        )
         if len(expr) == 1:
             el_str += ","
 
         return f"({el_str})"
 
-    def map_numpy_array(self, expr, enclosing_prec, *args, **kwargs):
+    def map_numpy_array(
+        self,
+        expr: np.ndarray,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         import numpy
 
         str_array = numpy.zeros(expr.shape, dtype="object")
@@ -345,68 +610,102 @@ class StringifyMapper(Mapper):
         if len(expr.shape) == 1 and max_length < 15:
             return "array({})".format(", ".join(str_array))
         else:
-            lines = ["  {}: {}\n".format(
-                ",".join(str(i_i) for i_i in i), val)
-                for i, val in numpy.ndenumerate(str_array)]
+            lines = [
+                "  {}: {}\n".format(",".join(str(i_i) for i_i in i), val)
+                for i, val in numpy.ndenumerate(str_array)
+            ]
             if max_length > 70:
-                splitter = "  " + "-"*75 + "\n"
+                splitter = "  " + "-" * 75 + "\n"
                 return "array(\n{})".format(splitter.join(lines))
             else:
                 return "array(\n{})".format("".join(lines))
 
-    def map_multivector(self, expr, enclosing_prec, *args, **kwargs):
+    def map_multivector(
+        self,
+        expr: MultiVector,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         return expr.stringify(self.rec, enclosing_prec, *args, **kwargs)
 
-    def map_common_subexpression(self, expr, enclosing_prec, *args, **kwargs):
+    def map_common_subexpression(
+        self,
+        expr: p.CommonSubexpression,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         from pymbolic.primitives import CommonSubexpression
+
         if type(expr) is CommonSubexpression:
             type_name = "CSE"
         else:
             type_name = type(expr).__name__
 
-        return self.format("%s(%s)",
-                type_name, self.rec(expr.child, PREC_NONE, *args, **kwargs))
-
-    def map_if(self, expr, enclosing_prec, *args, **kwargs):
-        return self.parenthesize_if_needed(
-                "{} if {} else {}".format(
-                    self.rec(expr.then, PREC_LOGICAL_OR, *args, **kwargs),
-                    self.rec(expr.condition, PREC_LOGICAL_OR, *args, **kwargs),
-                    self.rec(expr.else_, PREC_LOGICAL_OR, *args, **kwargs)),
-                enclosing_prec, PREC_IF)
+        return self.format(
+            "%s(%s)", type_name, self.rec(expr.child, PREC_NONE, *args, **kwargs)
+        )
 
-    def map_if_positive(self, expr, enclosing_prec, *args, **kwargs):
+    def map_if(
+        self, expr: p.If, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                "{} if {} > 0 else {}".format(
-                    self.rec(expr.then, PREC_LOGICAL_OR, *args, **kwargs),
-                    self.rec(expr.criterion, PREC_LOGICAL_OR, *args, **kwargs),
-                    self.rec(expr.else_, PREC_LOGICAL_OR, *args, **kwargs)),
-                enclosing_prec, PREC_IF)
-
-    def map_min(self, expr, enclosing_prec, *args, **kwargs):
+            "{} if {} else {}".format(
+                self.rec(expr.then, PREC_LOGICAL_OR, *args, **kwargs),
+                self.rec(expr.condition, PREC_LOGICAL_OR, *args, **kwargs),
+                self.rec(expr.else_, PREC_LOGICAL_OR, *args, **kwargs),
+            ),
+            enclosing_prec,
+            PREC_IF,
+        )
+
+    def map_min(
+        self, expr: p.Min, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         what = type(expr).__name__.lower()
-        return self.format("%s(%s)",
-                what, self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs))
-
-    map_max = map_min
-
-    def map_derivative(self, expr, enclosing_prec, *args, **kwargs):
+        return self.format(
+            "%s(%s)",
+            what,
+            self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs),
+        )
+
+    def map_max(
+        self, expr: p.Max, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
+        what = type(expr).__name__.lower()
+        return self.format(
+            "%s(%s)",
+            what,
+            self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs),
+        )
+
+    def map_derivative(
+        self, expr: p.Derivative, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         derivs = " ".join(f"d/d{v}" for v in expr.variables)
 
         return "{} {}".format(
-                derivs,
-                self.rec(expr.child, PREC_PRODUCT, *args, **kwargs))
-
-    def map_substitution(self, expr, enclosing_prec, *args, **kwargs):
+            derivs, self.rec(expr.child, PREC_PRODUCT, *args, **kwargs)
+        )
+
+    def map_substitution(
+        self,
+        expr: p.Substitution,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         substs = ", ".join(
-                "{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs))
-                for name, val in zip(expr.variables, expr.values))
+            "{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs))
+            for name, val in zip(expr.variables, expr.values)
+        )
 
-        return "[%s]{%s}" % (
-                self.rec(expr.child, PREC_NONE, *args, **kwargs),
-                substs)
+        return "[%s]{%s}" % (self.rec(expr.child, PREC_NONE, *args, **kwargs), substs)
 
-    def map_slice(self, expr, enclosing_prec, *args, **kwargs):
+    def map_slice(
+        self, expr: p.Slice, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         children = []
         for child in expr.children:
             if child is None:
@@ -415,15 +714,17 @@ class StringifyMapper(Mapper):
                 children.append(self.rec(child, PREC_NONE, *args, **kwargs))
 
         return self.parenthesize_if_needed(
-                self.join(":", children),
-                enclosing_prec, PREC_NONE)
+            ":".join(children), enclosing_prec, PREC_NONE
+        )
 
-    def map_nan(self, expr, enclosing_prec, *args, **kwargs):
+    def map_nan(
+        self, expr: p.NaN, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return "NaN"
 
     # }}}
 
-    def __call__(self, expr, prec=PREC_NONE, *args, **kwargs):
+    def __call__(self, expr, prec=PREC_NONE, *args: P.args, **kwargs: P.kwargs) -> str:
         """Return a string corresponding to *expr*. If the enclosing
         precedence level *prec* is higher than *prec* (see :ref:`prec-constants`),
         parenthesize the result.
@@ -437,15 +738,17 @@ class CachedStringifyMapper(StringifyMapper, CachedMapper):
         StringifyMapper.__init__(self)
         CachedMapper.__init__(self)
 
-    def __call__(self, expr, prec=PREC_NONE, *args, **kwargs):
+    def __call__(self, expr, prec=PREC_NONE, *args: P.args, **kwargs: P.kwargs) -> str:
         return CachedMapper.__call__(expr, prec, *args, **kwargs)
 
+
 # }}}
 
 
 # {{{ cse-splitting stringifier
 
-class CSESplittingStringifyMapperMixin:
+
+class CSESplittingStringifyMapperMixin(Mapper[str, Concatenate[int, P]]):
     """A :term:`mix-in` for subclasses of
     :class:`StringifyMapper` that collects
     "variable assignments" for
@@ -469,44 +772,45 @@ class CSESplittingStringifyMapperMixin:
     See :class:`pymbolic.mapper.c_code.CCodeMapper` for an example
     of the use of this mix-in.
     """
-    def __init__(self):
+
+    cse_to_name: dict[ExpressionT, str]
+    cse_names: set[str]
+    cse_name_list: list[tuple[str, str]]
+
+    def __init__(self) -> None:
         self.cse_to_name = {}
         self.cse_names = set()
         self.cse_name_list = []
 
         super().__init__()
 
-    def map_common_subexpression(self, expr, enclosing_prec, *args, **kwargs):
-        # This is here for compatibility, in case the constructor did not get called.
-        try:
-            self.cse_to_name  # noqa: B018
-        except AttributeError:
-            from warnings import warn
-            warn("Constructor of CSESplittingStringifyMapperMixin did not get "
-                 "called. This is deprecated and will stop working in 2022.",
-                 DeprecationWarning, stacklevel=2)
-
-            self.cse_to_name = {}
-            self.cse_names = set()
-            self.cse_name_list = []
-
+    def map_common_subexpression(
+        self,
+        expr: p.CommonSubexpression,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         try:
             cse_name = self.cse_to_name[expr.child]
         except KeyError:
             str_child = self.rec(expr.child, PREC_NONE, *args, **kwargs)
 
             if expr.prefix is not None:
+
                 def generate_cse_names():
                     yield expr.prefix
                     i = 2
                     while True:
                         yield expr.prefix + f"_{i}"
                         i += 1
+
             else:
+
                 def generate_cse_names():
                     i = 0
                     while True:
-                        yield "CSE"+str(i)
+                        yield "CSE" + str(i)
                         i += 1
 
             for cse_name in generate_cse_names():
@@ -519,51 +823,63 @@ class CSESplittingStringifyMapperMixin:
 
         return cse_name
 
-    def get_cse_strings(self):
-        return [f"{cse_name} : {cse_str}"
-                for cse_name, cse_str in
-                    sorted(getattr(self, "cse_name_list", []))]
+    def get_cse_strings(self) -> list[str]:
+        return [
+            f"{cse_name} : {cse_str}"
+            for cse_name, cse_str in sorted(getattr(self, "cse_name_list", []))
+        ]
+
 
 # }}}
 
 
 # {{{ sorting stringifier
 
-class SortingStringifyMapper(StringifyMapper):
+
+class SortingStringifyMapper(StringifyMapper[P]):
     def __init__(self, reverse=True):
         super().__init__()
         self.reverse = reverse
 
-    def map_sum(self, expr, enclosing_prec, *args, **kwargs):
+    def map_sum(
+        self, expr: p.Sum, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         entries = [self.rec(i, PREC_SUM, *args, **kwargs) for i in expr.children]
         entries.sort(reverse=self.reverse)
-        return self.parenthesize_if_needed(
-                self.join(" + ", entries),
-                enclosing_prec, PREC_SUM)
+        return self.parenthesize_if_needed("+".join(entries), enclosing_prec, PREC_SUM)
 
-    def map_product(self, expr, enclosing_prec, *args, **kwargs):
+    def map_product(
+        self, expr: p.Product, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         entries = [self.rec(i, PREC_PRODUCT, *args, **kwargs) for i in expr.children]
         entries.sort(reverse=self.reverse)
         return self.parenthesize_if_needed(
-                self.join("*", entries),
-                enclosing_prec, PREC_PRODUCT)
+            "*".join(entries), enclosing_prec, PREC_PRODUCT
+        )
+
 
 # }}}
 
 
 # {{{ simplifying, sorting stringifier
 
+
 class SimplifyingSortingStringifyMapper(StringifyMapper):
     def __init__(self, reverse=True):
         super().__init__()
         self.reverse = reverse
 
-    def map_sum(self, expr, enclosing_prec, *args, **kwargs):
+    def map_sum(
+        self, expr: p.Sum, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         def get_neg_product(expr):
             from pymbolic.primitives import Product, is_zero
 
-            if isinstance(expr, Product) \
-                    and len(expr.children) and is_zero(expr.children[0]+1):
+            if (
+                isinstance(expr, Product)
+                and len(expr.children)
+                and is_zero(expr.children[0] + 1)
+            ):
                 if len(expr.children) == 2:
                     # only the minus sign and the other child
                     return expr.children[1]
@@ -583,29 +899,32 @@ class SimplifyingSortingStringifyMapper(StringifyMapper):
                 positives.append(self.rec(ch, PREC_SUM, *args, **kwargs))
 
         positives.sort(reverse=self.reverse)
-        positives = " + ".join(positives)
+        positives_str = " + ".join(positives)
         negatives.sort(reverse=self.reverse)
-        negatives = self.join("",
-                [self.format(" - %s", entry) for entry in negatives])
+        negatives_str = "".join(self.format(" - %s", entry) for entry in negatives)
 
-        result = positives + negatives
+        result = positives_str + negatives_str
 
         return self.parenthesize_if_needed(result, enclosing_prec, PREC_SUM)
 
-    def map_product(self, expr, enclosing_prec, *args, **kwargs):
+    def map_product(
+        self, expr: p.Product, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         entries = []
         i = 0
         from pymbolic.primitives import is_zero
 
         while i < len(expr.children):
             child = expr.children[i]
-            if False and is_zero(child+1) and i+1 < len(expr.children):
+            if False and is_zero(child + 1) and i + 1 < len(expr.children):
                 # NOTE: That space needs to be there.
                 # Otherwise two unary minus signs merge into a pre-decrement.
                 entries.append(
-                        self.format(
-                            "- %s", self.rec(
-                                expr.children[i+1], PREC_UNARY, *args, **kwargs)))
+                    self.format(
+                        "- %s",
+                        self.rec(expr.children[i + 1], PREC_UNARY, *args, **kwargs),
+                    )
+                )
                 i += 2
             else:
                 entries.append(self.rec(child, PREC_PRODUCT, *args, **kwargs))
@@ -616,127 +935,226 @@ class SimplifyingSortingStringifyMapper(StringifyMapper):
 
         return self.parenthesize_if_needed(result, enclosing_prec, PREC_PRODUCT)
 
+
 # }}}
 
 
 # {{{ latex stringifier
 
-class LaTeXMapper(StringifyMapper):
 
+class LaTeXMapper(StringifyMapper):
     COMPARISON_OP_TO_LATEX: ClassVar[dict[str, str]] = {
         "==": r"=",
         "!=": r"\ne",
         "<=": r"\le",
         ">=": r"\ge",
-        "<":  r"<",
-        ">":  r">",
-        }
-
-    def map_remainder(self, expr, enclosing_prec, *args, **kwargs):
-        return self.format(r"(%s \bmod %s)",
-                self.rec(expr.numerator, PREC_PRODUCT, *args, **kwargs),
-                self.rec(expr.denominator, PREC_POWER, *args, **kwargs)),
+        "<": r"<",
+        ">": r">",
+    }
 
-    def map_left_shift(self, expr, enclosing_prec, *args, **kwargs):
+    def map_remainder(
+        self, expr: p.Remainder, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
+        return self.format(
+            r"(%s \bmod %s)",
+            self.rec(expr.numerator, PREC_PRODUCT, *args, **kwargs),
+            self.rec(expr.denominator, PREC_POWER, *args, **kwargs),
+        )
+
+    def map_left_shift(
+        self, expr: p.LeftShift, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.format(r"%s \ll %s",
-                    self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs),
-                    self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)),
-                enclosing_prec, PREC_SHIFT)
-
-    def map_right_shift(self, expr, enclosing_prec, *args, **kwargs):
+            self.format(
+                r"%s \ll %s",
+                self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs),
+                self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs),
+            ),
+            enclosing_prec,
+            PREC_SHIFT,
+        )
+
+    def map_right_shift(
+        self,
+        expr: p.RightShift,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.format(r"%s \gg %s",
-                    self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs),
-                    self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)),
-                enclosing_prec, PREC_SHIFT)
-
-    def map_bitwise_xor(self, expr, enclosing_prec, *args, **kwargs):
+            self.format(
+                r"%s \gg %s",
+                self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs),
+                self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs),
+            ),
+            enclosing_prec,
+            PREC_SHIFT,
+        )
+
+    def map_bitwise_xor(
+        self,
+        expr: p.BitwiseXor,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.join_rec(
-                    r" \wedge ", expr.children, PREC_BITWISE_XOR, *args, **kwargs),
-                enclosing_prec, PREC_BITWISE_XOR)
-
-    def map_product(self, expr, enclosing_prec, *args, **kwargs):
+            self.join_rec(
+                r" \wedge ", expr.children, PREC_BITWISE_XOR, *args, **kwargs
+            ),
+            enclosing_prec,
+            PREC_BITWISE_XOR,
+        )
+
+    def map_product(
+        self, expr: p.Product, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.join_rec(" ", expr.children, PREC_PRODUCT, *args, **kwargs),
-                enclosing_prec, PREC_PRODUCT)
-
-    def map_power(self, expr, enclosing_prec, *args, **kwargs):
+            self.join_rec(" ", expr.children, PREC_PRODUCT, *args, **kwargs),
+            enclosing_prec,
+            PREC_PRODUCT,
+        )
+
+    def map_power(
+        self, expr: p.Power, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.format("{%s}^{%s}",
-                    self.rec(expr.base, PREC_NONE, *args, **kwargs),
-                    self.rec(expr.exponent, PREC_NONE, *args, **kwargs)),
-                enclosing_prec, PREC_NONE)
-
-    def map_min(self, expr, enclosing_prec, *args, **kwargs):
+            self.format(
+                "{%s}^{%s}",
+                self.rec(expr.base, PREC_NONE, *args, **kwargs),
+                self.rec(expr.exponent, PREC_NONE, *args, **kwargs),
+            ),
+            enclosing_prec,
+            PREC_NONE,
+        )
+
+    def map_min(
+        self, expr: p.Min, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         from pytools import is_single_valued
+
         if is_single_valued(expr.children):
             return self.rec(expr.children[0], enclosing_prec)
 
         what = type(expr).__name__.lower()
-        return self.format(r"\%s(%s)",
-                what, self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs))
-
-    def map_max(self, expr, enclosing_prec):
-        return self.map_min(expr, enclosing_prec)
+        return self.format(
+            r"\%s(%s)",
+            what,
+            self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs),
+        )
+
+    def map_max(
+        self, expr: p.Max, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
+        from pytools import is_single_valued
 
-    def map_floor_div(self, expr, enclosing_prec, *args, **kwargs):
-        return self.format(r"\lfloor {%s} / {%s} \rfloor",
-                    self.rec(expr.numerator, PREC_NONE, *args, **kwargs),
-                    self.rec(expr.denominator, PREC_NONE, *args, **kwargs))
+        if is_single_valued(expr.children):
+            return self.rec(expr.children[0], enclosing_prec)
 
-    def map_subscript(self, expr, enclosing_prec, *args, **kwargs):
+        what = type(expr).__name__.lower()
+        return self.format(
+            r"\%s(%s)",
+            what,
+            self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs),
+        )
+
+    def map_floor_div(
+        self, expr: p.FloorDiv, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
+        return self.format(
+            r"\lfloor {%s} / {%s} \rfloor",
+            self.rec(expr.numerator, PREC_NONE, *args, **kwargs),
+            self.rec(expr.denominator, PREC_NONE, *args, **kwargs),
+        )
+
+    def map_subscript(
+        self, expr: p.Subscript, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         if isinstance(expr.index, tuple):
             index_str = self.join_rec(", ", expr.index, PREC_NONE, *args, **kwargs)
         else:
             index_str = self.rec(expr.index, PREC_NONE, *args, **kwargs)
 
-        return self.format("{%s}_{%s}",
-                    self.rec(expr.aggregate, PREC_CALL, *args, **kwargs),
-                    index_str)
-
-    def map_logical_not(self, expr, enclosing_prec, *args, **kwargs):
+        return self.format(
+            "{%s}_{%s}", self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), index_str
+        )
+
+    def map_logical_not(
+        self,
+        expr: p.LogicalNot,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         return self.parenthesize_if_needed(
-                r"\neg " + self.rec(expr.child, PREC_UNARY, *args, **kwargs),
-                enclosing_prec, PREC_UNARY)
-
-    def map_logical_or(self, expr, enclosing_prec, *args, **kwargs):
+            r"\neg " + self.rec(expr.child, PREC_UNARY, *args, **kwargs),
+            enclosing_prec,
+            PREC_UNARY,
+        )
+
+    def map_logical_or(
+        self, expr: p.LogicalOr, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.join_rec(
-                    r" \vee ", expr.children, PREC_LOGICAL_OR, *args, **kwargs),
-                enclosing_prec, PREC_LOGICAL_OR)
-
-    def map_logical_and(self, expr, enclosing_prec, *args, **kwargs):
+            self.join_rec(r" \vee ", expr.children, PREC_LOGICAL_OR, *args, **kwargs),
+            enclosing_prec,
+            PREC_LOGICAL_OR,
+        )
+
+    def map_logical_and(
+        self,
+        expr: p.LogicalAnd,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.join_rec(
-                    r" \wedge ", expr.children, PREC_LOGICAL_AND, *args, **kwargs),
-                enclosing_prec, PREC_LOGICAL_AND)
-
-    def map_comparison(self, expr, enclosing_prec, *args, **kwargs):
+            self.join_rec(
+                r" \wedge ", expr.children, PREC_LOGICAL_AND, *args, **kwargs
+            ),
+            enclosing_prec,
+            PREC_LOGICAL_AND,
+        )
+
+    def map_comparison(
+        self, expr: p.Comparison, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
         return self.parenthesize_if_needed(
-                self.format("%s %s %s",
-                    self.rec(expr.left, PREC_COMPARISON, *args, **kwargs),
-                    self.COMPARISON_OP_TO_LATEX[expr.operator],
-                    self.rec(expr.right, PREC_COMPARISON, *args, **kwargs)),
-                enclosing_prec, PREC_COMPARISON)
-
-    def map_substitution(self, expr, enclosing_prec, *args, **kwargs):
+            self.format(
+                "%s %s %s",
+                self.rec(expr.left, PREC_COMPARISON, *args, **kwargs),
+                self.COMPARISON_OP_TO_LATEX[expr.operator],
+                self.rec(expr.right, PREC_COMPARISON, *args, **kwargs),
+            ),
+            enclosing_prec,
+            PREC_COMPARISON,
+        )
+
+    def map_substitution(
+        self,
+        expr: p.Substitution,
+        enclosing_prec: int,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
         substs = ", ".join(
-                "{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs))
-                for name, val in zip(expr.variables, expr.values))
+            "{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs))
+            for name, val in zip(expr.variables, expr.values)
+        )
 
-        return self.format(r"[%s]\{%s\}",
-                self.rec(expr.child, PREC_NONE, *args, **kwargs),
-                substs)
+        return self.format(
+            r"[%s]\{%s\}", self.rec(expr.child, PREC_NONE, *args, **kwargs), substs
+        )
 
-    def map_derivative(self, expr, enclosing_prec, *args, **kwargs):
-        derivs = " ".join(
-                r"\frac{\partial}{\partial %s}" % v
-                for v in expr.variables)
+    def map_derivative(
+        self, expr: p.Derivative, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
+    ) -> str:
+        derivs = " ".join(r"\frac{\partial}{\partial %s}" % v for v in expr.variables)
+
+        return self.format(
+            "%s %s", derivs, self.rec(expr.child, PREC_PRODUCT, *args, **kwargs)
+        )
 
-        return self.format("%s %s",
-                derivs, self.rec(expr.child, PREC_PRODUCT, *args, **kwargs))
 
 # }}}
 
diff --git a/pymbolic/mapper/substitutor.py b/pymbolic/mapper/substitutor.py
index 8cdb3e12e67cb0291fc12e4b279a709193275697..04186342a002052fc8017e498d6e439fdbf0e35e 100644
--- a/pymbolic/mapper/substitutor.py
+++ b/pymbolic/mapper/substitutor.py
@@ -4,7 +4,16 @@
 .. autofunction:: make_subst_func
 .. autofunction:: substitute
 
+.. autoclass:: Callable[[AlgebraicLeaf], ExpressionT | None]
+
+References
+----------
+
+.. class:: SupportsGetItem
+
+    A protocol with a ``__getitem__`` method.
 """
+
 from __future__ import annotations
 
 
@@ -29,12 +38,19 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
+from typing import Any, Callable
+
+from useful_types import SupportsGetItem, SupportsItems
 
 from pymbolic.mapper import CachedIdentityMapper, IdentityMapper
+from pymbolic.primitives import AlgebraicLeaf
+from pymbolic.typing import ExpressionT
 
 
-class SubstitutionMapper(IdentityMapper):
-    def __init__(self, subst_func):
+class SubstitutionMapper(IdentityMapper[[]]):
+    def __init__(
+        self, subst_func: Callable[[AlgebraicLeaf], ExpressionT | None]
+    ) -> None:
         self.subst_func = subst_func
 
     def map_variable(self, expr):
@@ -59,17 +75,26 @@ class SubstitutionMapper(IdentityMapper):
             return IdentityMapper.map_lookup(self, expr)
 
 
-class CachedSubstitutionMapper(CachedIdentityMapper,
-                               SubstitutionMapper):
-    def __init__(self, subst_func):
-        CachedIdentityMapper.__init__(self)
+class CachedSubstitutionMapper(CachedIdentityMapper[[]], SubstitutionMapper):
+    def __init__(
+        self, subst_func: Callable[[AlgebraicLeaf], ExpressionT | None]
+    ) -> None:
+        # FIXME Mypy says:
+        # error: Argument 1 to "__init__" of "CachedMapper" has incompatible type
+        # "CachedSubstitutionMapper"; expected "CachedMapper[ResultT, P]"  [arg-type]
+        # This seems spurious?
+        CachedIdentityMapper.__init__(self)  # type: ignore[arg-type]
         SubstitutionMapper.__init__(self, subst_func)
 
 
-def make_subst_func(variable_assignments):
+def make_subst_func(
+    # "Any" here avoids the whole Mapping variance disaster
+    # e.g. https://github.com/python/typing/issues/445
+    variable_assignments: SupportsGetItem[Any, ExpressionT],
+) -> Callable[[AlgebraicLeaf], ExpressionT | None]:
     import pymbolic.primitives as primitives
 
-    def subst_func(var):
+    def subst_func(var: AlgebraicLeaf) -> ExpressionT | None:
         try:
             return variable_assignments[var]
         except KeyError:
@@ -84,15 +109,23 @@ def make_subst_func(variable_assignments):
     return subst_func
 
 
-def substitute(expression, variable_assignments=None,
-               mapper_cls=CachedSubstitutionMapper, **kwargs):
+def substitute(
+    expression: ExpressionT,
+    variable_assignments: SupportsItems[AlgebraicLeaf | str, ExpressionT] | None = None,
+    mapper_cls=CachedSubstitutionMapper,
+    **kwargs: ExpressionT,
+):
     """
     :arg mapper_cls: A :class:`type` of the substitution mapper
         whose instance applies the substitution.
     """
     if variable_assignments is None:
-        variable_assignments = {}
-    variable_assignments = variable_assignments.copy()
-    variable_assignments.update(kwargs)
+        # "Any" here avoids pointless grief about variance
+        # e.g. https://github.com/python/typing/issues/445
+        v_ass_copied: dict[Any, ExpressionT] = {}
+    else:
+        v_ass_copied = dict(variable_assignments.items())
+
+    v_ass_copied.update(kwargs)
 
-    return mapper_cls(make_subst_func(variable_assignments))(expression)
+    return mapper_cls(make_subst_func(v_ass_copied))(expression)
diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py
index 1b4977a875f7178eff32b6ca1f7e2fc8c3024d0c..251adc198144806f7a4bc16b53d7523e9f85ad52 100644
--- a/pymbolic/primitives.py
+++ b/pymbolic/primitives.py
@@ -38,7 +38,7 @@ from typing import (
 from warnings import warn
 
 from immutabledict import immutabledict
-from typing_extensions import TypeIs, dataclass_transform
+from typing_extensions import TypeAlias, TypeIs, dataclass_transform
 
 from . import traits
 from .typing import ArithmeticExpressionT, ExpressionT, NumberT, ScalarT
@@ -1650,6 +1650,12 @@ class Derivative(Expression):
     variables: tuple[str, ...]
 
 
+SliceChildrenT: TypeAlias = (tuple[()]
+        | tuple[ExpressionT | None]
+        | tuple[ExpressionT | None, ExpressionT | None]
+        | tuple[ExpressionT | None, ExpressionT | None, ExpressionT | None])
+
+
 @expr_dataclass()
 class Slice(Expression):
     """A slice expression as in a[1:7].
@@ -1661,10 +1667,7 @@ class Slice(Expression):
     .. autoproperty:: step
     """
 
-    children: (tuple[()]
-        | tuple[ExpressionT]
-        | tuple[ExpressionT, ExpressionT]
-        | tuple[ExpressionT, ExpressionT, ExpressionT])
+    children: SliceChildrenT
 
     def __bool__(self):
         return True
diff --git a/pymbolic/typing.py b/pymbolic/typing.py
index fce695d12f077a60fc4817bd8327f263c11ac57a..c1ecaa0fae7e2b05fa9cb0b743c51ac36af21716 100644
--- a/pymbolic/typing.py
+++ b/pymbolic/typing.py
@@ -4,18 +4,23 @@
 Typing helpers
 --------------
 
-.. autodata:: BoolT
-.. autodata:: NumberT
-.. autodata:: ScalarT
-.. autodata:: ArithmeticExpressionT
+.. autoclass:: BoolT
+.. autoclass:: NumberT
+.. autoclass:: ScalarT
+.. autoclass:: ArithmeticExpressionT
 
     A narrower type alias than :class:`ExpressionT` that is returned by
     arithmetic operators, to allow continue doing arithmetic with the result
     of arithmetic.
 
-    >
+.. autoclass:: ExpressionT
 
-.. autodata:: ExpressionT
+.. currentmodule:: pymbolic.typing
+
+.. autoclass:: ArithmeticOrExpressionT
+
+    A type variable that can be either :data:`ArithmeticExpressionT`
+    or :data:`ExpressionT`.
 """
 
 from __future__ import annotations
@@ -80,6 +85,11 @@ ArithmeticExpressionT: TypeAlias = Union[NumberT, "Expression"]
 
 ExpressionT: TypeAlias = Union[_ScalarOrExpression, Tuple["ExpressionT", ...]]
 
+ArithmeticOrExpressionT = TypeVar(
+                "ArithmeticOrExpressionT",
+                ArithmeticExpressionT,
+                ExpressionT)
+
 
 T = TypeVar("T")
 
diff --git a/pyproject.toml b/pyproject.toml
index 6d18ba83bcbadabdecb6b7bdc9bf07f79e28c310..ea60e24c12369fb77a7caa3da5399ac1cf32bbcd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,8 +32,9 @@ dependencies = [
     "astunparse; python_version<='3.9'",
     "immutabledict",
     "pytools>=2022.1.14",
-    # for dataclass_transform, TypeAlias
-    "typing-extensions>=4",
+    # for dataclass_transform, TypeAlias, deprecated
+    "typing-extensions>=4.5",
+    "useful-types",
 ]
 
 [project.optional-dependencies]