diff --git a/pytato/array.py b/pytato/array.py index 1391103fb33823cffd9f9f897f10e4d803788b01..046955f3181f0049670661ff3e359fdc37d79b46 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -88,6 +88,7 @@ Node constructors such as :class:`Placeholder.__init__` and import numpy as np import pymbolic.primitives as prim import pytato.scalar_expr as scalar_expr +from pytato.scalar_expr import ScalarExpression from dataclasses import dataclass from pytools import is_single_valued @@ -247,16 +248,15 @@ TagsType = FrozenSet[Tag] # {{{ shape -ShapeComponentType = Union[int, prim.Expression] -ShapeType = Tuple[ShapeComponentType, ...] +ShapeType = Tuple[ScalarExpression, ...] ConvertibleToShapeComponent = Union[int, prim.Expression, str] ConvertibleToShape = Union[ str, - prim.Expression, + ScalarExpression, Tuple[ConvertibleToShapeComponent, ...]] -def _check_identifier(s, ns: Optional[Namespace] = None): +def _check_identifier(s: str, ns: Optional[Namespace] = None) -> bool: if not str.isidentifier(s): raise ValueError(f"'{s}' is not a valid identifier") @@ -270,8 +270,9 @@ def _check_identifier(s, ns: Optional[Namespace] = None): class _ShapeChecker(scalar_expr.WalkMapper): def __init__(self, ns: Optional[Namespace] = None): super().__init__() + self.ns = ns - def map_variable(self, expr): + def map_variable(self, expr: prim.Variable) -> None: _check_identifier(expr.name, self.ns) super().map_variable(expr) @@ -286,7 +287,8 @@ def normalize_shape( """ from pytato.scalar_expr import parse - def nnormalize_shape_component(s): + def normalize_shape_component( + s: ConvertibleToShapeComponent) -> ScalarExpression: if isinstance(s, str): s = parse(s) @@ -303,10 +305,11 @@ def normalize_shape( if isinstance(shape, str): shape = parse(shape) - if isinstance(shape, (int, prim.Expression)): + from numbers import Number + if isinstance(shape, (Number, prim.Expression)): shape = (shape,) - return tuple(nnormalize_shape_component(s) for s in shape) + return tuple(normalize_shape_component(s) for s in shape) # }}} diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index aa430eecbec5f644b34b59a8740203058f8500a8..c6759845b7b83a53c4e253e481bcdf96f7a6f031 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -25,9 +25,16 @@ THE SOFTWARE. """ from pymbolic.mapper import WalkMapper as WalkMapperBase +import pymbolic.primitives as prim +from numbers import Number +from typing import Union -def parse(s): + +ScalarExpression = Union[Number, prim.Expression] + + +def parse(s: str) -> ScalarExpression: from pymbolic.parser import Parser return Parser()(s) diff --git a/setup.cfg b/setup.cfg index d16b02e0d810d3f4d90efbdc494bd2e6e2422508..b975cd3a4f878d7de702bd97d6a86d6e65b9ab4b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,6 +2,9 @@ ignore = E126,E127,E128,E123,E226,E241,E242,E265,N802,W503,E402,N814,N817,W504 max-line-length=85 +[mypy-pytato.scalar_expr] +disallow_subclassing_any = False + [mypy-pymbolic] ignore_missing_imports = True