diff --git a/pytato/array.py b/pytato/array.py
index 1391103fb33823cffd9f9f897f10e4d803788b01..046955f3181f0049670661ff3e359fdc37d79b46 100644
--- a/pytato/array.py
+++ b/pytato/array.py
@@ -88,6 +88,7 @@ Node constructors such as :class:`Placeholder.__init__` and
 import numpy as np
 import pymbolic.primitives as prim
 import pytato.scalar_expr as scalar_expr
+from pytato.scalar_expr import ScalarExpression
 
 from dataclasses import dataclass
 from pytools import is_single_valued
@@ -247,16 +248,15 @@ TagsType = FrozenSet[Tag]
 
 # {{{ shape
 
-ShapeComponentType = Union[int, prim.Expression]
-ShapeType = Tuple[ShapeComponentType, ...]
+ShapeType = Tuple[ScalarExpression, ...]
 ConvertibleToShapeComponent = Union[int, prim.Expression, str]
 ConvertibleToShape = Union[
         str,
-        prim.Expression,
+        ScalarExpression,
         Tuple[ConvertibleToShapeComponent, ...]]
 
 
-def _check_identifier(s, ns: Optional[Namespace] = None):
+def _check_identifier(s: str, ns: Optional[Namespace] = None) -> bool:
     if not str.isidentifier(s):
         raise ValueError(f"'{s}' is not a valid identifier")
 
@@ -270,8 +270,9 @@ def _check_identifier(s, ns: Optional[Namespace] = None):
 class _ShapeChecker(scalar_expr.WalkMapper):
     def __init__(self, ns: Optional[Namespace] = None):
         super().__init__()
+        self.ns = ns
 
-    def map_variable(self, expr):
+    def map_variable(self, expr: prim.Variable) -> None:
         _check_identifier(expr.name, self.ns)
         super().map_variable(expr)
 
@@ -286,7 +287,8 @@ def normalize_shape(
     """
     from pytato.scalar_expr import parse
 
-    def nnormalize_shape_component(s):
+    def normalize_shape_component(
+            s: ConvertibleToShapeComponent) -> ScalarExpression:
         if isinstance(s, str):
             s = parse(s)
 
@@ -303,10 +305,11 @@ def normalize_shape(
     if isinstance(shape, str):
         shape = parse(shape)
 
-    if isinstance(shape, (int, prim.Expression)):
+    from numbers import Number
+    if isinstance(shape, (Number, prim.Expression)):
         shape = (shape,)
 
-    return tuple(nnormalize_shape_component(s) for s in shape)
+    return tuple(normalize_shape_component(s) for s in shape)
 
 # }}}
 
diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py
index aa430eecbec5f644b34b59a8740203058f8500a8..c6759845b7b83a53c4e253e481bcdf96f7a6f031 100644
--- a/pytato/scalar_expr.py
+++ b/pytato/scalar_expr.py
@@ -25,9 +25,16 @@ THE SOFTWARE.
 """
 
 from pymbolic.mapper import WalkMapper as WalkMapperBase
+import pymbolic.primitives as prim
 
+from numbers import Number
+from typing import Union
 
-def parse(s):
+
+ScalarExpression = Union[Number, prim.Expression]
+
+
+def parse(s: str) -> ScalarExpression:
     from pymbolic.parser import Parser
     return Parser()(s)
 
diff --git a/setup.cfg b/setup.cfg
index d16b02e0d810d3f4d90efbdc494bd2e6e2422508..b975cd3a4f878d7de702bd97d6a86d6e65b9ab4b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -2,6 +2,9 @@
 ignore = E126,E127,E128,E123,E226,E241,E242,E265,N802,W503,E402,N814,N817,W504
 max-line-length=85
 
+[mypy-pytato.scalar_expr]
+disallow_subclassing_any = False
+
 [mypy-pymbolic]
 ignore_missing_imports = True