diff --git a/doc/conf.py b/doc/conf.py
index 3b20a6ac703b72fd97dc4717b9f8e7e593cf7cd6..ba52921ccff64166c0f5f1df77d438ef061dc426 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -42,6 +42,8 @@ nitpick_ignore_regex = [
     # Understandable, because typing can't import primitives, which would be needed
     # to resolve the reference.
     ["py:class", r"ExpressionNode"],
+    ["py:class", r"_Expression"],
+    ["py:class", r"p\.AlgebraicLeaf"],
     ]
 
 
diff --git a/pymbolic/algorithm.py b/pymbolic/algorithm.py
index f6da10e91048056a96fafff2267fb1091b91ce8c..c4cfb2c4f2b74d8782538d43f276337c7e140663 100644
--- a/pymbolic/algorithm.py
+++ b/pymbolic/algorithm.py
@@ -43,10 +43,14 @@ from warnings import warn
 from pytools import MovedFunctionDeprecationWrapper, memoize
 
 
-if TYPE_CHECKING or getattr(sys, "_BUILDING_SPHINX_DOCS", None):
+if TYPE_CHECKING:
     import numpy as np
 
 
+if getattr(sys, "_BUILDING_SPHINX_DOCS", None):
+    import numpy as np  # noqa: TC002
+
+
 # {{{ integer powers
 
 def integer_power(x, n, one=1):
diff --git a/pymbolic/geometric_algebra/__init__.py b/pymbolic/geometric_algebra/__init__.py
index 86bbd03e701019715ccfee5e0ba2e9b803c96b10..38e95122251757e8fc84acff617c7bffc80895f0 100644
--- a/pymbolic/geometric_algebra/__init__.py
+++ b/pymbolic/geometric_algebra/__init__.py
@@ -567,7 +567,7 @@ class MultiVector(Generic[CoeffT]):
                     f"are supported for 'data': shape {data.shape}")
 
             dimensions, = data.shape
-            data_dict = {(i,): cast(CoeffT, xi) for i, xi in enumerate(data)}
+            data_dict = {(i,): cast("CoeffT", xi) for i, xi in enumerate(data)}
 
             if space is None:
                 space = get_euclidean_space(dimensions)
@@ -579,7 +579,7 @@ class MultiVector(Generic[CoeffT]):
         elif isinstance(data, Mapping):
             data_dict = data
         else:
-            data_dict = {0: cast(CoeffT, data)}
+            data_dict = {0: cast("CoeffT", data)}
 
         if space is None:
             raise ValueError("No 'space' provided")
@@ -595,8 +595,8 @@ class MultiVector(Generic[CoeffT]):
                 assert isinstance(basis_indices, tuple)
 
                 bits, sign = space.bits_and_sign(basis_indices)
-                new_coeff = cast(CoeffT,
-                    new_data.setdefault(bits, cast(CoeffT, 0))  # type: ignore[operator]
+                new_coeff = cast("CoeffT",
+                    new_data.setdefault(bits, cast("CoeffT", 0))  # type: ignore[operator]
                     + sign*coeff)
 
                 if is_zero(new_coeff):
@@ -604,7 +604,7 @@ class MultiVector(Generic[CoeffT]):
                 else:
                     new_data[bits] = new_coeff
         else:
-            new_data = cast(dict[int, CoeffT], data_dict)
+            new_data = cast("dict[int, CoeffT]", data_dict)
 
         # }}}
 
@@ -691,8 +691,8 @@ class MultiVector(Generic[CoeffT]):
         from pymbolic.primitives import is_zero
         new_data = {}
         for bits in all_bits:
-            new_coeff = (self.data.get(bits, cast(CoeffT, 0))
-                + other.data.get(bits, cast(CoeffT, 0)))
+            new_coeff = (self.data.get(bits, cast("CoeffT", 0))
+                + other.data.get(bits, cast("CoeffT", 0)))
 
             if not is_zero(new_coeff):
                 new_data[bits] = new_coeff
@@ -741,7 +741,7 @@ class MultiVector(Generic[CoeffT]):
                     coeff = (weight
                             * canonical_reordering_sign(sbits, obits)
                             * scoeff * ocoeff)
-                    new_coeff = new_data.setdefault(new_bits, cast(CoeffT, 0)) + coeff
+                    new_coeff = new_data.setdefault(new_bits, cast("CoeffT", 0)) + coeff
                     if is_zero(new_coeff):
                         del new_data[new_bits]
                     else:
@@ -1134,7 +1134,7 @@ def componentwise(f: Callable[[CoeffT], CoeffT], expr: T) -> T:
     """
 
     if isinstance(expr, MultiVector):
-        return cast(T, expr.map(f))
+        return cast("T", expr.map(f))
 
     from pytools.obj_array import obj_array_vectorize
     return obj_array_vectorize(f, expr)
diff --git a/pymbolic/geometric_algebra/mapper.py b/pymbolic/geometric_algebra/mapper.py
index 085a3ca946bcb855c0fb3a54ce16383cb283915a..a66a66dd75c8a1aa7f485d5d7c196c5a2049a317 100644
--- a/pymbolic/geometric_algebra/mapper.py
+++ b/pymbolic/geometric_algebra/mapper.py
@@ -25,8 +25,7 @@ THE SOFTWARE.
 
 # This is experimental, undocumented, and could go away any second.
 # Consider yourself warned.
-from collections.abc import Set
-from typing import ClassVar
+from typing import TYPE_CHECKING, ClassVar
 
 import pymbolic.geometric_algebra.primitives as prim
 from pymbolic.geometric_algebra import MultiVector
@@ -49,7 +48,12 @@ from pymbolic.mapper.stringifier import (
     PREC_NONE,
     StringifyMapper as StringifyMapperBase,
 )
-from pymbolic.primitives import ExpressionNode
+
+
+if TYPE_CHECKING:
+    from collections.abc import Set
+
+    from pymbolic.primitives import ExpressionNode
 
 
 class IdentityMapper(IdentityMapperBase[P]):
diff --git a/pymbolic/geometric_algebra/primitives.py b/pymbolic/geometric_algebra/primitives.py
index 47de49b192de0d257b782bc6c7450f526f8f80b2..2c81d9d60ddf556f20c4ec13c593fe8092c62386 100644
--- a/pymbolic/geometric_algebra/primitives.py
+++ b/pymbolic/geometric_algebra/primitives.py
@@ -26,11 +26,15 @@ THE SOFTWARE.
 # This is experimental, undocumented, and could go away any second.
 # Consider yourself warned.
 
-from collections.abc import Hashable
-from typing import ClassVar
+from typing import TYPE_CHECKING, ClassVar
 
 from pymbolic.primitives import ExpressionNode, Variable, expr_dataclass
-from pymbolic.typing import Expression
+
+
+if TYPE_CHECKING:
+    from collections.abc import Hashable
+
+    from pymbolic.typing import Expression
 
 
 class MultiVectorVariable(Variable):
diff --git a/pymbolic/interop/ast.py b/pymbolic/interop/ast.py
index db4201e43e8e58b12138093faa66614e976b5252..2dc0f79686eb41d63d8a30104b8ff6d2b3d868d1 100644
--- a/pymbolic/interop/ast.py
+++ b/pymbolic/interop/ast.py
@@ -27,11 +27,14 @@ THE SOFTWARE.
 """
 
 import ast
-from typing import Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar
 
 import pymbolic.primitives as p
 from pymbolic.mapper import CachedMapper
-from pymbolic.typing import Expression
+
+
+if TYPE_CHECKING:
+    from pymbolic.typing import Expression
 
 
 __doc__ = r'''
diff --git a/pymbolic/interop/matchpy/mapper.py b/pymbolic/interop/matchpy/mapper.py
index 44f25b6e82cb6a4b1dd1837590a6556f007e09a2..43f38b33d8faacc456e2c5890bc7e0e8c57f9a05 100644
--- a/pymbolic/interop/matchpy/mapper.py
+++ b/pymbolic/interop/matchpy/mapper.py
@@ -1,9 +1,12 @@
 from __future__ import annotations
 
-from collections.abc import Callable
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
-from pymbolic.interop.matchpy import PymbolicOp
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+    from pymbolic.interop.matchpy import PymbolicOp
 
 
 class Mapper:
diff --git a/pymbolic/interop/matchpy/tofrom.py b/pymbolic/interop/matchpy/tofrom.py
index db5e564a1bb68cd423c5938ee31932d1062425de..0963d82e738881005ee5a4ea49b7247c68d8a0c7 100644
--- a/pymbolic/interop/matchpy/tofrom.py
+++ b/pymbolic/interop/matchpy/tofrom.py
@@ -1,8 +1,7 @@
 from __future__ import annotations
 
-from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import multiset
 import numpy as np
@@ -12,7 +11,12 @@ import pymbolic.interop.matchpy as m
 import pymbolic.primitives as p
 from pymbolic.interop.matchpy.mapper import Mapper as BaseMatchPyMapper
 from pymbolic.mapper import Mapper as BasePymMapper
-from pymbolic.typing import Scalar as PbScalar
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+    from pymbolic.typing import Scalar as PbScalar
 
 
 # {{{ to matchpy
diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py
index b1fc6ee321f38fab435c804e611aeca332f2ad07..06ef30e551177801fbad8a1c3e763989e4186d74 100644
--- a/pymbolic/mapper/__init__.py
+++ b/pymbolic/mapper/__init__.py
@@ -429,7 +429,7 @@ class CachedMapper(Mapper[ResultT, P]):
         method_name = getattr(expr, "mapper_method", None)
         if method_name is not None:
             method = cast(
-                Callable[Concatenate[Expression, P], ResultT] | None,
+                "Callable[Concatenate[Expression, P], ResultT] | None",
                 getattr(self, method_name, None)
                 )
             if method is not None:
@@ -973,7 +973,7 @@ class IdentityMapper(Mapper[Expression, P]):
                 *args: P.args, **kwargs: P.kwargs
             ) -> Expression:
         # True fact: MultiVectors aren't expressions
-        return expr.map(lambda ch: cast(ArithmeticExpression,
+        return expr.map(lambda ch: cast("ArithmeticExpression",
                                         self.rec(ch, *args, **kwargs)))  # type: ignore[return-value]
 
     def map_common_subexpression(self,
@@ -1012,7 +1012,7 @@ class IdentityMapper(Mapper[Expression, P]):
     def map_slice(self,
                 expr: p.Slice,
                 *args: P.args, **kwargs: P.kwargs) -> Expression:
-        children: p.SliceChildrenT = cast(p.SliceChildrenT, tuple([
+        children: p.SliceChildrenT = cast("p.SliceChildrenT", tuple([
             None if child is None else self.rec(child, *args, **kwargs)
             for child in expr.children
             ]))
diff --git a/pymbolic/mapper/coefficient.py b/pymbolic/mapper/coefficient.py
index 72315f9846e47bb3bd5643bedba4311c6a99dfb8..ea223eb56517fb2cfe5816c8b3d1b655e59e0045 100644
--- a/pymbolic/mapper/coefficient.py
+++ b/pymbolic/mapper/coefficient.py
@@ -68,7 +68,7 @@ class CoefficientCollector(Mapper[CoeffsT, []]):
         for i, child_coeffs in enumerate(children_coeffs):
             if i != idx_of_child_with_vars:
                 assert len(child_coeffs) == 1
-                other_coeffs *= cast(ArithmeticExpression, child_coeffs[1])
+                other_coeffs *= cast("ArithmeticExpression", child_coeffs[1])
 
         if idx_of_child_with_vars is None:
             return {1: other_coeffs}
diff --git a/pymbolic/mapper/collector.py b/pymbolic/mapper/collector.py
index c6c89eff6b2acc58beec8414a03178142b7f2c55..02cdee8656f09a36abf97c671c8fc013f3236086 100644
--- a/pymbolic/mapper/collector.py
+++ b/pymbolic/mapper/collector.py
@@ -26,14 +26,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from collections.abc import Sequence, Set
-from typing import cast
+from typing import TYPE_CHECKING, cast
 
 import pymbolic
 import pymbolic.primitives as p
 from pymbolic.mapper import IdentityMapper
-from pymbolic.mapper.dependency import DependenciesT
-from pymbolic.typing import ArithmeticExpression, Expression
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence, Set
+
+    from pymbolic.mapper.dependency import DependenciesT
+    from pymbolic.typing import ArithmeticExpression, Expression
 
 
 class TermCollector(IdentityMapper[[]]):
@@ -110,7 +114,7 @@ class TermCollector(IdentityMapper[[]]):
 
         base_exp_set = frozenset(
                 (base, exp) for base, exp in cleaned_base2exp.items())
-        return base_exp_set, cast(ArithmeticExpression,
+        return base_exp_set, cast("ArithmeticExpression",
                 self.rec(pymbolic.flattened_product(coefficients)))
 
     def map_sum(self, expr: p.Sum) -> Expression:
diff --git a/pymbolic/mapper/constant_folder.py b/pymbolic/mapper/constant_folder.py
index 70f3b1b7abd4060042a53e0de1b659fc3f9ebcf1..93590888b7d44aaf7df4e3b47a495d7c5ff6f123 100644
--- a/pymbolic/mapper/constant_folder.py
+++ b/pymbolic/mapper/constant_folder.py
@@ -27,7 +27,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from collections.abc import Callable
+
+from typing import TYPE_CHECKING
 
 from pymbolic.mapper import (
     CSECachingMapperMixin,
@@ -38,6 +39,10 @@ from pymbolic.primitives import Product, Sum, is_arithmetic_expression
 from pymbolic.typing import ArithmeticExpression, Expression
 
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
 class ConstantFoldingMapperBase(Mapper[Expression, []]):
     def is_constant(self, expr):
         from pymbolic.mapper.dependency import DependencyMapper
diff --git a/pymbolic/mapper/distributor.py b/pymbolic/mapper/distributor.py
index b8e33d5c2b2c5c1a3a76cbf02cf42288625e5ec6..8867c3018ce0aa2f1482639e1b1a4ffc15962add 100644
--- a/pymbolic/mapper/distributor.py
+++ b/pymbolic/mapper/distributor.py
@@ -27,14 +27,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import cast
+from typing import TYPE_CHECKING, cast
 
 import pymbolic
 import pymbolic.primitives as p
 from pymbolic.mapper import IdentityMapper
 from pymbolic.mapper.collector import TermCollector
 from pymbolic.mapper.constant_folder import CommutativeConstantFoldingMapper
-from pymbolic.typing import ArithmeticExpression, Expression
+
+
+if TYPE_CHECKING:
+    from pymbolic.typing import ArithmeticExpression, Expression
 
 
 class DistributeMapper(IdentityMapper[[]]):
@@ -118,7 +121,7 @@ class DistributeMapper(IdentityMapper[[]]):
         newbase = self.rec(expr.base)
         if isinstance(newbase, p.Product):
             return self.rec(pymbolic.flattened_product([
-                cast(ArithmeticExpression, child)**expr.exponent
+                cast("ArithmeticExpression", child)**expr.exponent
                     for child in newbase.children
                 ]))
 
diff --git a/pymbolic/mapper/evaluator.py b/pymbolic/mapper/evaluator.py
index d38a395c8fd8de60a2f07da6b3986fa9e3f18d5c..1dab45cbb0c3eaff234927830897d61883df27b1 100644
--- a/pymbolic/mapper/evaluator.py
+++ b/pymbolic/mapper/evaluator.py
@@ -34,19 +34,20 @@ THE SOFTWARE.
 """
 
 import operator as op
-from collections.abc import Mapping
 from functools import reduce
 from typing import TYPE_CHECKING, cast
 
-import pymbolic.primitives as p
 from pymbolic.mapper import CachedMapper, CSECachingMapperMixin, Mapper, ResultT
-from pymbolic.typing import Expression
 
 
 if TYPE_CHECKING:
+    from collections.abc import Mapping
+
     import numpy as np
 
+    import pymbolic.primitives as p
     from pymbolic.geometric_algebra import MultiVector
+    from pymbolic.typing import Expression
 
 
 class UnknownVariableError(Exception):
@@ -82,7 +83,7 @@ class EvaluationMapper(Mapper[ResultT, []], CSECachingMapperMixin):
         self.context = context
 
     def map_constant(self, expr: object) -> ResultT:
-        return cast(ResultT, expr)
+        return cast("ResultT", expr)
 
     def map_variable(self, expr: p.Variable) -> ResultT:
         try:
diff --git a/pymbolic/mapper/flattener.py b/pymbolic/mapper/flattener.py
index e5972c0b575c87428bfbe669c9ed8e48e6550084..316c1e83be8c6d856befe339f5caeef9c209a348 100644
--- a/pymbolic/mapper/flattener.py
+++ b/pymbolic/mapper/flattener.py
@@ -31,15 +31,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import cast
+from typing import TYPE_CHECKING, cast
 
 import pymbolic.primitives as p
 from pymbolic.mapper import IdentityMapper
-from pymbolic.typing import (
-    ArithmeticExpression,
-    ArithmeticOrExpressionT,
-    Expression,
-)
+
+
+if TYPE_CHECKING:
+    from pymbolic.typing import (
+        ArithmeticExpression,
+        ArithmeticOrExpressionT,
+        Expression,
+    )
 
 
 class FlattenMapper(IdentityMapper[[]]):
@@ -68,13 +71,13 @@ class FlattenMapper(IdentityMapper[[]]):
     def map_sum(self, expr: p.Sum) -> Expression:
         from pymbolic.primitives import flattened_sum
         return flattened_sum([
-                             cast(ArithmeticExpression, self.rec(ch))
+                             cast("ArithmeticExpression", self.rec(ch))
                              for ch in expr.children])
 
     def map_product(self, expr: p.Product) -> Expression:
         from pymbolic.primitives import flattened_product
         return flattened_product([
-                                 cast(ArithmeticExpression, self.rec(ch))
+                                 cast("ArithmeticExpression", self.rec(ch))
                                  for ch in expr.children])
 
     def map_quotient(self, expr: p.Quotient) -> Expression:
@@ -123,4 +126,4 @@ class FlattenMapper(IdentityMapper[[]]):
 
 
 def flatten(expr: ArithmeticOrExpressionT) -> ArithmeticOrExpressionT:
-    return cast(ArithmeticOrExpressionT, FlattenMapper()(expr))
+    return cast("ArithmeticOrExpressionT", FlattenMapper()(expr))
diff --git a/pymbolic/mapper/optimize.py b/pymbolic/mapper/optimize.py
index a07757ccdb351d740c2a76fb8741a4b618bc882c..27500722d6791f0be93002e365c88d65c4d62449 100644
--- a/pymbolic/mapper/optimize.py
+++ b/pymbolic/mapper/optimize.py
@@ -24,9 +24,12 @@ THE SOFTWARE.
 """
 
 import ast
-from collections.abc import Callable, Iterable, MutableMapping
 from functools import cached_property, lru_cache
-from typing import TextIO, TypeVar, cast
+from typing import TYPE_CHECKING, TextIO, TypeVar, cast
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable, Iterable, MutableMapping
 
 
 # This machinery applies AST rewriting to the mapper in a mildly brutal
@@ -130,7 +133,7 @@ class _RecInliner(ast.NodeTransformer):
         self.inline_cache = inline_cache
 
     def visit_Call(self, node: ast.Call) -> ast.AST:  # noqa: N802
-        node = cast(ast.Call, self.generic_visit(node))
+        node = cast("ast.Call", self.generic_visit(node))
 
         result_expr: ast.expr = node
 
@@ -397,7 +400,7 @@ def optimize_mapper(
             "exec"),
              compile_dict)
 
-        return cast(type, compile_dict[cls.__name__])
+        return cast("type", compile_dict[cls.__name__])
 
     return wrapper
 
diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py
index 7e14aabfabd8582a86693bf4aa43ead6f4db6615..72a01bdd119809e04b988ba792c389a1c6b1641c 100644
--- a/pymbolic/mapper/stringifier.py
+++ b/pymbolic/mapper/stringifier.py
@@ -22,7 +22,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
-from collections.abc import Sequence
 from typing import TYPE_CHECKING, ClassVar, Concatenate
 from warnings import warn
 
@@ -30,13 +29,15 @@ from typing_extensions import deprecated
 
 import pymbolic.primitives as p
 from pymbolic.mapper import CachedMapper, Mapper, P
-from pymbolic.typing import Expression
 
 
 if TYPE_CHECKING:
+    from collections.abc import Sequence
+
     import numpy as np
 
     from pymbolic.geometric_algebra import MultiVector
+    from pymbolic.typing import Expression
 
 
 __doc__ = """
diff --git a/pymbolic/mapper/substitutor.py b/pymbolic/mapper/substitutor.py
index 7948fca7500f1083e7066e94467c9c22e8320104..a755d295e3a925b5610789e88fa7597fbf6eef9d 100644
--- a/pymbolic/mapper/substitutor.py
+++ b/pymbolic/mapper/substitutor.py
@@ -38,17 +38,24 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
-from collections.abc import Callable
+import sys
 from typing import TYPE_CHECKING, Any
 
 from pymbolic.mapper import CachedIdentityMapper, IdentityMapper
-from pymbolic.primitives import AlgebraicLeaf
-from pymbolic.typing import Expression
 
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from useful_types import SupportsGetItem, SupportsItems
 
+    from pymbolic.primitives import AlgebraicLeaf
+    from pymbolic.typing import Expression
+
+
+if getattr(sys, "_BUILDING_SPHINX_DOCS", None):
+    from collections.abc import Callable  # noqa: TC003
+
 
 class SubstitutionMapper(IdentityMapper[[]]):
     def __init__(
diff --git a/pymbolic/parser.py b/pymbolic/parser.py
index 8a12e4714b50e3202dc9cae4713c03912ec4a060..97c9c9f330fd98a831fed6f9dcb539ed5d5b9277 100644
--- a/pymbolic/parser.py
+++ b/pymbolic/parser.py
@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-from pymbolic.typing import Expression
-
 
 __copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
 
@@ -27,7 +25,7 @@ THE SOFTWARE.
 
 from collections.abc import Sequence
 from sys import intern
-from typing import ClassVar, TypeAlias
+from typing import TYPE_CHECKING, ClassVar, TypeAlias
 
 from immutabledict import immutabledict
 
@@ -35,6 +33,10 @@ import pytools.lex
 from pytools import memoize_method
 
 
+if TYPE_CHECKING:
+    from pymbolic.typing import Expression
+
+
 _imaginary = intern("imaginary")
 _float = intern("float")
 _int = intern("int")
diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py
index caf0dbfb8eebdda4d61df1579075f63c7c763a2f..6a11449603e8f25322d67918c48e376c6ae180d3 100644
--- a/pymbolic/primitives.py
+++ b/pymbolic/primitives.py
@@ -24,7 +24,6 @@ THE SOFTWARE.
 """
 
 import re
-from collections.abc import Callable, Iterable, Mapping
 from dataclasses import dataclass, fields
 from functools import partial
 from sys import intern
@@ -50,6 +49,8 @@ from .typing import ArithmeticExpression, Expression as _Expression, Number, Sca
 
 
 if TYPE_CHECKING:
+    from collections.abc import Callable, Iterable, Mapping
+
     from _typeshed import DataclassInstance
 
 
@@ -1058,7 +1059,7 @@ def _augment_expression_dataclass(
 
     # {{{ assign mapper_method
 
-    mm_cls = cast(type[_HasMapperMethod], cls)
+    mm_cls = cast("type[_HasMapperMethod]", cls)
 
     snake_clsname = _CAMEL_TO_SNAKE_RE.sub("_", mm_cls.__name__).lower()
     default_mapper_method_name = f"map_{snake_clsname}"
@@ -1793,7 +1794,7 @@ def flattened_sum(terms: Iterable[ArithmeticExpression]) -> ArithmeticExpression
             continue
 
         if isinstance(item, Sum):
-            ch = cast(tuple[ArithmeticExpression], item.children)
+            ch = cast("tuple[ArithmeticExpression]", item.children)
             queue.extend(ch)
         else:
             done.append(item)
@@ -1835,7 +1836,7 @@ def flattened_product(terms: Iterable[ArithmeticExpression]) -> ArithmeticExpres
             continue
 
         if isinstance(item, Product):
-            ch = cast(tuple[ArithmeticExpression], item.children)
+            ch = cast("tuple[ArithmeticExpression]", item.children)
             queue.extend(ch)
         else:
             done.append(item)
diff --git a/pyproject.toml b/pyproject.toml
index 90434e39a2be36ae3b315bf94355e7db0cb1eb9a..b22018e42f9e6f8fa788bb9639da579585886536 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,6 +68,7 @@ extend-select = [
     "RUF",  # ruff
     "UP",   # pyupgrade
     "W",    # pycodestyle
+    "TC",   # type checking
 ]
 extend-ignore = [
     "C409", # remove comprehension within tuple call
diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py
index be992d7fe556742f2d8b9b3f135ae86b5f3d3691..29fda763f8ebd2d317826c6fb65a18513caa4c03 100644
--- a/test/test_pymbolic.py
+++ b/test/test_pymbolic.py
@@ -3,7 +3,6 @@ from __future__ import annotations
 from pymbolic.mapper.evaluator import evaluate_kw
 from pymbolic.mapper.flattener import FlattenMapper
 from pymbolic.mapper.stringifier import StringifyMapper
-from pymbolic.typing import Expression
 
 
 __copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
@@ -30,6 +29,7 @@ THE SOFTWARE.
 
 import logging
 from functools import reduce
+from typing import TYPE_CHECKING
 
 import pytest
 from testlib import generate_random_expression
@@ -42,6 +42,10 @@ from pymbolic.mapper import IdentityMapper, WalkMapper
 from pymbolic.mapper.dependency import CachedDependencyMapper, DependencyMapper
 
 
+if TYPE_CHECKING:
+    from pymbolic.typing import Expression
+
+
 logger = logging.getLogger(__name__)