diff --git a/pymbolic/geometric_algebra/__init__.py b/pymbolic/geometric_algebra/__init__.py index 6a6ced186be52ae49514f9fa294a1df0af9ba6c3..6eca83a1f48c34cfa160db945b3fef14af8c6f40 100644 --- a/pymbolic/geometric_algebra/__init__.py +++ b/pymbolic/geometric_algebra/__init__.py @@ -24,7 +24,7 @@ THE SOFTWARE. """ from abc import ABC, abstractmethod -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from dataclasses import dataclass from typing import Any, Generic, TypeVar, cast @@ -32,8 +32,8 @@ import numpy as np from pytools import memoize, memoize_method -from pymbolic.primitives import expr_dataclass -from pymbolic.typing import ArithmeticExpressionT +from pymbolic.primitives import expr_dataclass, is_zero +from pymbolic.typing import ArithmeticExpressionT, T __doc__ = """ @@ -1104,7 +1104,7 @@ class MultiVector(Generic[CoeffT]): # {{{ helper functions - def map(self, f): + def map(self, f: Callable[[CoeffT], CoeffT]) -> MultiVector[CoeffT]: """Return a new :class:`MultiVector` with coefficients mapped by function *f*, which takes a single coefficient as input and returns the new coefficient. @@ -1127,14 +1127,14 @@ class MultiVector(Generic[CoeffT]): # }}} -def componentwise(f, expr): +def componentwise(f: Callable[[CoeffT], CoeffT], expr: T) -> T: """Apply function *f* componentwise to object arrays and :class:`MultiVector` instances. *expr* is also allowed to be a scalar. """ if isinstance(expr, MultiVector): - return expr.map(f) + return cast(T, expr.map(f)) from pytools.obj_array import obj_array_vectorize return obj_array_vectorize(f, expr) diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index abb59a197ce31d061c44a3e1ce5e944d9c590148..25a4a8cb8595132aed0a50fd1bd9cead48d3e25d 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -972,7 +972,9 @@ class IdentityMapper(Mapper[ExpressionT, P]): expr: MultiVector[ArithmeticExpressionT], *args: P.args, **kwargs: P.kwargs ) -> ExpressionT: - return expr.map(lambda ch: self.rec(ch, *args, **kwargs)) + # True fact: MultiVectors aren't expressions + return expr.map(lambda ch: cast(ArithmeticExpressionT, + self.rec(ch, *args, **kwargs))) # type: ignore[return-value] def map_common_subexpression(self, expr: p.CommonSubexpression, diff --git a/pymbolic/mapper/evaluator.py b/pymbolic/mapper/evaluator.py index 9aee59fcf51e0b7d98ac2e2dd70a14cb87b9443f..bdff4c2d4d35de7100953f3d2812cf1fe2686ef0 100644 --- a/pymbolic/mapper/evaluator.py +++ b/pymbolic/mapper/evaluator.py @@ -166,7 +166,7 @@ class EvaluationMapper(Mapper[ResultT, []], CSECachingMapperMixin): return result # type: ignore[return-value] def map_multivector(self, expr: MultiVector) -> ResultT: - return expr.map(lambda ch: self.rec(ch)) + return expr.map(lambda ch: self.rec(ch)) # type: ignore[return-value] def map_common_subexpression_uncached(self, expr: p.CommonSubexpression) -> ResultT: return self.rec(expr.child)