diff --git a/pymbolic/geometric_algebra/mapper.py b/pymbolic/geometric_algebra/mapper.py index bf5b64ac6428bd013bad39a13c6eebe1a314eecc..d2508f9addc96868edcbcd568d72febf40bf2e11 100644 --- a/pymbolic/geometric_algebra/mapper.py +++ b/pymbolic/geometric_algebra/mapper.py @@ -25,16 +25,19 @@ THE SOFTWARE. # This is experimental, undocumented, and could go away any second. # Consider yourself warned. - +from collections.abc import Set from typing import ClassVar import pymbolic.geometric_algebra.primitives as prim from pymbolic.geometric_algebra import MultiVector from pymbolic.mapper import ( CachedMapper, + CollectedT, Collector as CollectorBase, CombineMapper as CombineMapperBase, IdentityMapper as IdentityMapperBase, + P, + ResultT, WalkMapper as WalkMapperBase, ) from pymbolic.mapper.constant_folder import ( @@ -46,50 +49,66 @@ from pymbolic.mapper.stringifier import ( PREC_NONE, StringifyMapper as StringifyMapperBase, ) +from pymbolic.primitives import Expression -class IdentityMapper(IdentityMapperBase): - def map_multivector_variable(self, expr): +class IdentityMapper(IdentityMapperBase[P]): + def map_nabla( + self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs) -> Expression: return expr - map_nabla = map_multivector_variable - map_nabla_component = map_multivector_variable + def map_nabla_component(self, + expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs) -> Expression: + return expr - def map_derivative_source(self, expr): - operand = self.rec(expr.operand) + def map_derivative_source(self, + expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs + ) -> Expression: + operand = self.rec(expr.operand, *args, **kwargs) if operand is expr.operand: return expr return type(expr)(operand, expr.nabla_id) -class CombineMapper(CombineMapperBase): - def map_derivative_source(self, expr): - return self.rec(expr.operand) +class CombineMapper(CombineMapperBase[ResultT, P]): + def map_derivative_source( + self, expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + return self.rec(expr.operand, *args, **kwargs) -class Collector(CollectorBase): - def map_nabla(self, expr): +class Collector(CollectorBase[CollectedT, P]): + def map_nabla(self, + expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs + ) -> Set[CollectedT]: return set() - map_nabla_component = map_nabla + def map_nabla_component(self, + expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs + ) -> Set[CollectedT]: + return set() -class WalkMapper(WalkMapperBase): - def map_nabla(self, expr, *args): - self.visit(expr, *args) - self.post_visit(expr) +class WalkMapper(WalkMapperBase[P]): + def map_nabla(self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs) -> None: + self.visit(expr, *args, **kwargs) + self.post_visit(expr, *args, **kwargs) - def map_nabla_component(self, expr, *args): - self.visit(expr, *args) - self.post_visit(expr) + def map_nabla_component( + self, expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs + ) -> None: + self.visit(expr, *args, **kwargs) + self.post_visit(expr, *args, **kwargs) - def map_derivative_source(self, expr, *args): - if not self.visit(expr, *args): + def map_derivative_source( + self, expr, *args: P.args, **kwargs: P.kwargs + ) -> None: + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.operand) - self.post_visit(expr) + self.rec(expr.operand, *args, **kwargs) + self.post_visit(expr, *args, **kwargs) class EvaluationMapper(EvaluationMapperBase): @@ -106,7 +125,7 @@ class EvaluationMapper(EvaluationMapperBase): return type(expr)(operand, expr.nabla_id) -class StringifyMapper(StringifyMapperBase): +class StringifyMapper(StringifyMapperBase[[]]): AXES: ClassVar[dict[int, str]] = {0: "x", 1: "y", 2: "z"} def map_nabla(self, expr, enclosing_prec):