From a8db8e6c81251402685bf1bd04447b2e8d0d0f15 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Mon, 7 Oct 2024 13:13:52 -0500 Subject: [PATCH] Type the mappers Co-authored-by: Alexandru Fikl <alexfikl@gmail.com> --- pymbolic/interop/ast.py | 8 +- pymbolic/mapper/__init__.py | 913 +++++++++++++++++------ pymbolic/mapper/coefficient.py | 52 +- pymbolic/mapper/collector.py | 46 +- pymbolic/mapper/constant_folder.py | 48 +- pymbolic/mapper/dependency.py | 118 +-- pymbolic/mapper/distributor.py | 34 +- pymbolic/mapper/evaluator.py | 99 ++- pymbolic/mapper/flattener.py | 35 +- pymbolic/mapper/optimize.py | 10 +- pymbolic/mapper/stringifier.py | 1080 +++++++++++++++++++--------- pymbolic/mapper/substitutor.py | 61 +- pymbolic/primitives.py | 13 +- pymbolic/typing.py | 22 +- pyproject.toml | 5 +- 15 files changed, 1797 insertions(+), 747 deletions(-) diff --git a/pymbolic/interop/ast.py b/pymbolic/interop/ast.py index 743edf6..4b6cc1c 100644 --- a/pymbolic/interop/ast.py +++ b/pymbolic/interop/ast.py @@ -31,7 +31,7 @@ from typing import Any, ClassVar import pymbolic.primitives as p from pymbolic.mapper import CachedMapper -from pymbolic.typing import ExpressionT, ScalarT +from pymbolic.typing import ExpressionT __doc__ = r''' @@ -263,7 +263,7 @@ class ASTToPymbolic(ASTMapper): # {{{ PymbolicToASTMapper -class PymbolicToASTMapper(CachedMapper): +class PymbolicToASTMapper(CachedMapper[ast.expr, []]): def map_variable(self, expr) -> ast.expr: return ast.Name(id=expr.name) @@ -283,7 +283,7 @@ class PymbolicToASTMapper(CachedMapper): def map_product(self, expr: p.Product) -> ast.expr: return self._map_multi_children_op(expr.children, ast.Mult()) - def map_constant(self, expr: ScalarT) -> ast.expr: + def map_constant(self, expr: object) -> ast.expr: if isinstance(expr, bool): return ast.NameConstant(expr) else: @@ -393,7 +393,7 @@ class PymbolicToASTMapper(CachedMapper): raise NotImplementedError("Non-float nan not implemented") def map_slice(self, expr: p.Slice) -> ast.expr: - return ast.Slice(*[self.rec(child) + return ast.Slice(*[None if child is None else self.rec(child) for child in expr.children]) def map_numpy_array(self, expr) -> ast.expr: diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 22315f9..65dbfb9 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -24,17 +24,41 @@ THE SOFTWARE. """ from abc import ABC, abstractmethod -from typing import Any +from collections.abc import Mapping +from typing import ( + TYPE_CHECKING, + AbstractSet, + Callable, + Generic, + Hashable, + Iterable, + TypeVar, + cast, +) +from warnings import warn from immutabledict import immutabledict +from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeIs -import pymbolic.primitives as primitives +import pymbolic.primitives as p +from pymbolic.typing import ArithmeticExpressionT, ExpressionT + + +if TYPE_CHECKING: + import numpy as np + + from pymbolic.geometric_algebra import MultiVector + from pymbolic.rational import Rational __doc__ = """ Basic dispatch -------------- +.. class:: ResultT + + A type variable for the result returned by a :class:`Mapper`. + .. autoclass:: Mapper .. automethod:: __call__ @@ -96,14 +120,20 @@ Base classes for mappers with memoization support """ -try: - import numpy +if TYPE_CHECKING: + import numpy as np - def is_numpy_array(val): - return isinstance(val, numpy.ndarray) -except ImportError: - def is_numpy_array(ary): - return False + def is_numpy_array(val) -> TypeIs[np.ndarray]: + return isinstance(val, np.ndarray) +else: + try: + import numpy as np + + def is_numpy_array(val): + return isinstance(val, np.ndarray) + except ImportError: + def is_numpy_array(ary): + return False class UnsupportedExpressionError(ValueError): @@ -112,15 +142,28 @@ class UnsupportedExpressionError(ValueError): # {{{ mapper base -class Mapper: +ResultT = TypeVar("ResultT") + +# This ParamSpec could be marked contravariant (just like Callable is contravariant +# in its arguments). As of mypy 1.14/Py3.13 (Nov 2024), mypy complains of as-yet +# undefined semantics, so it's probably too soon. +P = ParamSpec("P") + + +class Mapper(Generic[ResultT, P]): """A visitor for trees of :class:`pymbolic.Expression` subclasses. Each expression-derived object is dispatched to the method named by the :attr:`pymbolic.Expression.mapper_method` attribute and if not found, the methods named by the class attribute *mapper_method* in the method resolution order of the object. + + ..automethod:: handle_unsupported_expression + ..automethod:: __call__ + ..automethod:: rec """ - def handle_unsupported_expression(self, expr, *args, **kwargs): + def handle_unsupported_expression(self, + expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Mapper method that is invoked for :class:`pymbolic.Expression` subclasses for which a mapper method does not exist in this mapper. @@ -130,7 +173,8 @@ class Mapper: "{} cannot handle expressions of type {}".format( type(self), type(expr))) - def __call__(self, expr, *args, **kwargs): + def __call__(self, + expr: ExpressionT, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Dispatch *expr* to its corresponding mapper method. Pass on ``*args`` and ``**kwargs`` unmodified. @@ -148,7 +192,7 @@ class Mapper: result = method(expr, *args, **kwargs) return result - if isinstance(expr, primitives.Expression): + if isinstance(expr, p.Expression): for cls in type(expr).__mro__[1:]: method_name = getattr(cls, "mapper_method", None) if method_name: @@ -162,8 +206,9 @@ class Mapper: rec = __call__ - def rec_fallback(self, expr, *args, **kwargs): - if isinstance(expr, primitives.Expression): + def rec_fallback(self, + expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT: + if isinstance(expr, p.Expression): for cls in type(expr).__mro__[1:]: method_name = getattr(cls, "mapper_method", None) if method_name: @@ -175,76 +220,188 @@ class Mapper: else: return self.map_foreign(expr, *args, **kwargs) - def map_algebraic_leaf(self, expr, *args, **kwargs): + def map_algebraic_leaf(self, + expr: p.AlgebraicLeaf, + *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_variable(self, expr, *args, **kwargs): + def map_variable(self, + expr: p.Variable, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript(self, + expr: p.Subscript, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_call(self, expr, *args, **kwargs): + def map_call(self, + expr: p.Call, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_lookup(self, expr, *args, **kwargs): + def map_call_with_kwargs(self, + expr: p.CallWithKwargs, + *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_if_positive(self, expr, *args, **kwargs): + def map_lookup(self, + expr: p.Lookup, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_rational(self, expr, *args, **kwargs): - return self.map_quotient(expr, *args, **kwargs) + def map_if(self, + expr: p.If, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError - def map_quotient(self, expr, *args, **kwargs): + def map_sum(self, + expr: p.Sum, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_constant(self, expr, *args, **kwargs): + def map_product(self, + expr: p.Product, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_rational(self, + expr: Rational, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_quotient(self, + expr: p.Quotient, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_floor_div(self, + expr: p.FloorDiv, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_remainder(self, + expr: p.Remainder, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_constant(self, + expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_comparison(self, + expr: p.Comparison, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_min(self, + expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_max(self, + expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_list(self, + expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_tuple(self, + expr: tuple[ExpressionT, ...], + *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_numpy_array(self, + expr: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_left_shift(self, + expr: p.LeftShift, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: raise NotImplementedError - def map_list(self, expr, *args, **kwargs): + def map_right_shift(self, + expr: p.RightShift, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: raise NotImplementedError - def map_tuple(self, expr, *args, **kwargs): + def map_bitwise_not(self, + expr: p.BitwiseNot, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: raise NotImplementedError - def map_numpy_array(self, expr, *args, **kwargs): + def map_bitwise_or(self, + expr: p.BitwiseOr, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: raise NotImplementedError - def map_nan(self, expr, *args, **kwargs): + def map_bitwise_and(self, + expr: p.BitwiseAnd, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + raise NotImplementedError + + def map_bitwise_xor(self, + expr: p.BitwiseXor, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + raise NotImplementedError + + def map_logical_not(self, + expr: p.LogicalNot, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + raise NotImplementedError + + def map_logical_or(self, + expr: p.LogicalOr, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + raise NotImplementedError + + def map_logical_and(self, + expr: p.LogicalAnd, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + raise NotImplementedError + + def map_nan(self, + expr: p.NaN, + *args: P.args, + **kwargs: P.kwargs + ) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_foreign(self, expr, *args, **kwargs): + def map_foreign(self, + expr: object, + *args: P.args, + **kwargs: P.kwargs + ) -> ResultT: """Mapper method dispatch for non-:mod:`pymbolic` objects.""" - if isinstance(expr, primitives.VALID_CONSTANT_CLASSES): + if isinstance(expr, p.VALID_CONSTANT_CLASSES): return self.map_constant(expr, *args, **kwargs) elif is_numpy_array(expr): return self.map_numpy_array(expr, *args, **kwargs) - elif isinstance(expr, list): - return self.map_list(expr, *args, **kwargs) elif isinstance(expr, tuple): return self.map_tuple(expr, *args, **kwargs) + elif isinstance(expr, list): + warn("List found in expression graph. " + "This is deprecated and will stop working in 2025. " + "Use tuples instead.", DeprecationWarning, stacklevel=2 + ) + return self.map_list(expr, *args, **kwargs) else: raise ValueError( "{} encountered invalid foreign object: {}".format( self.__class__, repr(expr))) -_NOT_IN_CACHE = object() +class _NotInCache: + pass + +CacheKeyT: TypeAlias = Hashable -class CachedMapper(Mapper): + +class CachedMapper(Mapper[ResultT, P]): """ A mapper that memoizes the mapped result for the expressions traversed. .. automethod:: get_cache_key """ - def __init__(self): - self._cache: dict[Any, Any] = {} + def __init__(self) -> None: + self._cache: dict[CacheKeyT, ResultT] = {} Mapper.__init__(self) - def get_cache_key(self, expr, *args, **kwargs): + def get_cache_key(self, + expr: ExpressionT, + *args: P.args, + **kwargs: P.kwargs + ) -> CacheKeyT: """ Returns the key corresponding to which the result of a mapper method is stored in the cache. @@ -260,16 +417,23 @@ class CachedMapper(Mapper): # and "4 == 4.0", but their traversal results cannot be re-used. return (type(expr), expr, args, immutabledict(kwargs)) - def __call__(self, expr, *args, **kwargs): + def __call__(self, + expr: ExpressionT, + *args: P.args, + **kwargs: P.kwargs + ) -> ResultT: result = self._cache.get( (cache_key := self.get_cache_key(expr, *args, **kwargs)), - _NOT_IN_CACHE) - if result is not _NOT_IN_CACHE: + _NotInCache) + if not isinstance(result, type): return result method_name = getattr(expr, "mapper_method", None) if method_name is not None: - method = getattr(self, method_name, None) + method = cast( + Callable[Concatenate[ExpressionT, P], ResultT], + getattr(self, method_name, None) + ) if method is not None: result = method(expr, *args, **kwargs) self._cache[cache_key] = result @@ -286,7 +450,7 @@ class CachedMapper(Mapper): # {{{ combine mapper -class CombineMapper(Mapper): +class CombineMapper(Mapper[ResultT, P]): """A mapper whose goal it is to *combine* all branches of the expression tree into one final result. The default implementation of all mapper methods simply recurse (:meth:`Mapper.rec`) on all branches emanating from @@ -304,16 +468,19 @@ class CombineMapper(Mapper): :class:`pymbolic.mapper.dependency.DependencyMapper` is another example. """ - def combine(self, values): + def combine(self, values: Iterable[ResultT]) -> ResultT: raise NotImplementedError - def map_call(self, expr, *args, **kwargs): + def map_call(self, + expr: p.Call, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.function, *args, **kwargs), *[self.rec(child, *args, **kwargs) for child in expr.parameters] )) - def map_call_with_kwargs(self, expr, *args, **kwargs): + def map_call_with_kwargs(self, + expr: p.CallWithKwargs, + *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.function, *args, **kwargs), *[self.rec(child, *args, **kwargs) for child in expr.parameters], @@ -321,87 +488,141 @@ class CombineMapper(Mapper): for child in expr.kw_parameters.values()] )) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript(self, + expr: p.Subscript, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine( [self.rec(expr.aggregate, *args, **kwargs), self.rec(expr.index, *args, **kwargs)]) - def map_lookup(self, expr, *args, **kwargs): + def map_lookup(self, + expr: p.Lookup, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.rec(expr.aggregate, *args, **kwargs) - def map_sum(self, expr, *args, **kwargs): + def map_sum(self, + expr: p.Sum, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(self.rec(child, *args, **kwargs) for child in expr.children) - map_product = map_sum + def map_product(self, + expr: p.Product, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) - def map_quotient(self, expr, *args, **kwargs): + def map_quotient(self, + expr: p.Quotient, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.numerator, *args, **kwargs), self.rec(expr.denominator, *args, **kwargs))) - map_floor_div = map_quotient - map_remainder = map_quotient + def map_floor_div(self, + expr: p.FloorDiv, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(( + self.rec(expr.numerator, *args, **kwargs), + self.rec(expr.denominator, *args, **kwargs))) + + def map_remainder(self, + expr: p.Remainder, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(( + self.rec(expr.numerator, *args, **kwargs), + self.rec(expr.denominator, *args, **kwargs))) - def map_power(self, expr, *args, **kwargs): + def map_power(self, + expr: p.Power, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.base, *args, **kwargs), self.rec(expr.exponent, *args, **kwargs))) - def map_polynomial(self, expr, *args, **kwargs): + def map_left_shift(self, + expr: p.LeftShift, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( - self.rec(expr.base, *args, **kwargs), - *[self.rec(coeff, *args, **kwargs) for exp, coeff in expr.data] - )) + self.rec(expr.shiftee, *args, **kwargs), + self.rec(expr.shift, *args, **kwargs))) - def map_left_shift(self, expr, *args, **kwargs): + def map_right_shift(self, + expr: p.RightShift, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.shiftee, *args, **kwargs), self.rec(expr.shift, *args, **kwargs))) - map_right_shift = map_left_shift + def map_bitwise_not(self, + expr: p.BitwiseNot, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.rec(expr.child, *args, **kwargs) - def map_bitwise_not(self, expr, *args, **kwargs): + def map_bitwise_or(self, + expr: p.BitwiseOr, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) + + def map_bitwise_and(self, + expr: p.BitwiseAnd, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) + + def map_bitwise_xor(self, + expr: p.BitwiseXor, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) + + def map_logical_not(self, + expr: p.LogicalNot, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.rec(expr.child, *args, **kwargs) - map_bitwise_or = map_sum - map_bitwise_xor = map_sum - map_bitwise_and = map_sum - map_logical_not = map_bitwise_not - map_logical_and = map_sum - map_logical_or = map_sum + def map_logical_or(self, + expr: p.LogicalOr, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) + + def map_logical_and(self, + expr: p.LogicalAnd, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) - def map_comparison(self, expr, *args, **kwargs): + def map_comparison(self, + expr: p.Comparison, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.left, *args, **kwargs), self.rec(expr.right, *args, **kwargs))) - map_max = map_sum - map_min = map_sum + def map_max(self, + expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) + + def map_min(self, + expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) - def map_list(self, expr, *args, **kwargs): + def map_tuple(self, + expr: tuple[ExpressionT, ...], *args: P.args, **kwargs: P.kwargs + ) -> ResultT: return self.combine(self.rec(child, *args, **kwargs) for child in expr) - map_tuple = map_list + def map_list(self, + expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) for child in expr) - def map_numpy_array(self, expr, *args, **kwargs): + def map_numpy_array(self, + expr: np.ndarray, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: return self.combine(self.rec(el, *args, **kwargs) for el in expr.flat) - def map_multivector(self, expr, *args, **kwargs): + def map_multivector(self, + expr: MultiVector[ArithmeticExpressionT], + *args: P.args, **kwargs: P.kwargs + ) -> ResultT: return self.combine( self.rec(coeff, *args, **kwargs) for bits, coeff in expr.data.items()) - def map_common_subexpression(self, expr, *args, **kwargs): + def map_common_subexpression(self, + expr: p.CommonSubexpression, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: return self.rec(expr.child, *args, **kwargs) - def map_if_positive(self, expr, *args, **kwargs): - return self.combine([ - self.rec(expr.criterion, *args, **kwargs), - self.rec(expr.then, *args, **kwargs), - self.rec(expr.else_, *args, **kwargs)]) - - def map_if(self, expr, *args, **kwargs): + def map_if(self, + expr: p.If, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine([ self.rec(expr.condition, *args, **kwargs), self.rec(expr.then, *args, **kwargs), @@ -416,7 +637,10 @@ class CachedCombineMapper(CachedMapper, CombineMapper): # {{{ collector -class Collector(CombineMapper): +CollectedT = TypeVar("CollectedT") + + +class Collector(CombineMapper[AbstractSet[CollectedT], P]): """A subclass of :class:`CombineMapper` for the common purpose of collecting data derived from an expression in a set that gets 'unioned' across children at each non-leaf node in the expression tree. @@ -426,19 +650,36 @@ class Collector(CombineMapper): .. versionadded:: 2014.3 """ - def combine(self, values): + def combine(self, + values: Iterable[AbstractSet[CollectedT]] + ) -> AbstractSet[CollectedT]: import operator from functools import reduce return reduce(operator.or_, values, set()) - def map_constant(self, expr, *args, **kwargs): + def map_constant(self, expr: object, + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: return set() - map_variable = map_constant - map_wildcard = map_constant - map_dot_wildcard = map_constant - map_star_wildcard = map_constant - map_function_symbol = map_constant + def map_variable(self, expr: p.Variable, + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: + return set() + + def map_wildcard(self, expr: p.Wildcard, + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: + return set() + + def map_dot_wildcard(self, expr: p.DotWildcard, + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: + return set() + + def map_star_wildcard(self, expr: p.StarWildcard, + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: + return set() + + def map_function_symbol(self, expr: p.FunctionSymbol, + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: + return set() class CachedCollector(CachedMapper, Collector): @@ -449,34 +690,59 @@ class CachedCollector(CachedMapper, Collector): # {{{ identity mapper -class IdentityMapper(Mapper): +class IdentityMapper(Mapper[ExpressionT, P]): """A :class:`Mapper` whose default mapper methods make a deep copy of each subexpression. See :ref:`custom-manipulation` for an example of the manipulations that can be implemented this way. + + .. automethod:: rec_arith """ - def map_constant(self, expr, *args, **kwargs): + + def rec_arith(self, + expr: ArithmeticExpressionT, *args: P.args, **kwargs: P.kwargs + ) -> ArithmeticExpressionT: + res = self.rec(expr, *args, **kwargs) + assert p.is_arithmetic_expression(res) + return res + + def map_constant(self, + expr: object, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: # leaf -- no need to rebuild + assert p.is_valid_operand(expr) return expr - def map_variable(self, expr, *args, **kwargs): + def map_variable(self, + expr: p.Variable, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: # leaf -- no need to rebuild return expr - def map_wildcard(self, expr, *args, **kwargs): + def map_wildcard(self, + expr: p.Wildcard, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: return expr - def map_dot_wildcard(self, expr, *args, **kwargs): + def map_dot_wildcard(self, + expr: p.DotWildcard, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: return expr - def map_star_wildcard(self, expr, *args, **kwargs): + def map_star_wildcard(self, + expr: p.StarWildcard, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: return expr - def map_function_symbol(self, expr, *args, **kwargs): + def map_function_symbol(self, + expr: p.FunctionSymbol, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: return expr - def map_call(self, expr, *args, **kwargs): + def map_call(self, + expr: p.Call, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: function = self.rec(expr.function, *args, **kwargs) parameters = tuple([ self.rec(child, *args, **kwargs) for child in expr.parameters @@ -488,12 +754,14 @@ class IdentityMapper(Mapper): return type(expr)(function, parameters) - def map_call_with_kwargs(self, expr, *args, **kwargs): + def map_call_with_kwargs(self, + expr: p.CallWithKwargs, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: function = self.rec(expr.function, *args, **kwargs) parameters = tuple([ self.rec(child, *args, **kwargs) for child in expr.parameters ]) - kw_parameters = immutabledict({ + kw_parameters: Mapping[str, ExpressionT] = immutabledict({ key: self.rec(val, *args, **kwargs) for key, val in expr.kw_parameters.items()}) @@ -505,20 +773,26 @@ class IdentityMapper(Mapper): return expr return type(expr)(function, parameters, kw_parameters) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript(self, + expr: p.Subscript, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: aggregate = self.rec(expr.aggregate, *args, **kwargs) index = self.rec(expr.index, *args, **kwargs) if aggregate is expr.aggregate and index is expr.index: return expr return type(expr)(aggregate, index) - def map_lookup(self, expr, *args, **kwargs): + def map_lookup(self, + expr: p.Lookup, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: aggregate = self.rec(expr.aggregate, *args, **kwargs) if aggregate is expr.aggregate: return expr return type(expr)(aggregate, expr.name) - def map_sum(self, expr, *args, **kwargs): + def map_sum(self, + expr: p.Sum, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child for child, orig_child in zip(children, expr.children)): @@ -526,41 +800,81 @@ class IdentityMapper(Mapper): return type(expr)(tuple(children)) - map_product = map_sum + def map_product(self, + expr: p.Product, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + children = [self.rec(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child + for child, orig_child in zip(children, expr.children)): + return expr + + return type(expr)(tuple(children)) + + def map_quotient(self, + expr: p.Quotient, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + numerator = self.rec_arith(expr.numerator, *args, **kwargs) + denominator = self.rec_arith(expr.denominator, *args, **kwargs) + if numerator is expr.numerator and denominator is expr.denominator: + return expr + return expr.__class__(numerator, denominator) - def map_quotient(self, expr, *args, **kwargs): - numerator = self.rec(expr.numerator, *args, **kwargs) - denominator = self.rec(expr.denominator, *args, **kwargs) + def map_floor_div(self, + expr: p.FloorDiv, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + numerator = self.rec_arith(expr.numerator, *args, **kwargs) + denominator = self.rec_arith(expr.denominator, *args, **kwargs) if numerator is expr.numerator and denominator is expr.denominator: return expr return expr.__class__(numerator, denominator) - map_floor_div = map_quotient - map_remainder = map_quotient + def map_remainder(self, + expr: p.Remainder, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + numerator = self.rec_arith(expr.numerator, *args, **kwargs) + denominator = self.rec_arith(expr.denominator, *args, **kwargs) + if numerator is expr.numerator and denominator is expr.denominator: + return expr + return expr.__class__(numerator, denominator) - def map_power(self, expr, *args, **kwargs): - base = self.rec(expr.base, *args, **kwargs) - exponent = self.rec(expr.exponent, *args, **kwargs) + def map_power(self, + expr: p.Power, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + base = self.rec_arith(expr.base, *args, **kwargs) + exponent = self.rec_arith(expr.exponent, *args, **kwargs) if base is expr.base and exponent is expr.exponent: return expr return expr.__class__(base, exponent) - def map_left_shift(self, expr, *args, **kwargs): + def map_left_shift(self, + expr: p.LeftShift, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: shiftee = self.rec(expr.shiftee, *args, **kwargs) shift = self.rec(expr.shift, *args, **kwargs) if shiftee is expr.shiftee and shift is expr.shift: return expr return type(expr)(shiftee, shift) - map_right_shift = map_left_shift + def map_right_shift(self, + expr: p.RightShift, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + shiftee = self.rec(expr.shiftee, *args, **kwargs) + shift = self.rec(expr.shift, *args, **kwargs) + if shiftee is expr.shiftee and shift is expr.shift: + return expr + return type(expr)(shiftee, shift) - def map_bitwise_not(self, expr, *args, **kwargs): + def map_bitwise_not(self, + expr: p.BitwiseNot, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: child = self.rec(expr.child, *args, **kwargs) if child is expr.child: return expr return type(expr)(child) - def map_bitwise_or(self, expr, *args, **kwargs): + def map_bitwise_or(self, + expr: p.BitwiseOr, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child for child, orig_child in zip(children, expr.children)): @@ -568,14 +882,57 @@ class IdentityMapper(Mapper): return type(expr)(tuple(children)) - map_bitwise_xor = map_bitwise_or - map_bitwise_and = map_bitwise_or + def map_bitwise_and(self, + expr: p.BitwiseAnd, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + children = [self.rec(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child + for child, orig_child in zip(children, expr.children)): + return expr - map_logical_not = map_bitwise_not - map_logical_or = map_bitwise_or - map_logical_and = map_bitwise_or + return type(expr)(tuple(children)) - def map_comparison(self, expr, *args, **kwargs): + def map_bitwise_xor(self, + expr: p.BitwiseXor, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + children = [self.rec(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child + for child, orig_child in zip(children, expr.children)): + return expr + + return type(expr)(tuple(children)) + + def map_logical_not(self, + expr: p.LogicalNot, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + child = self.rec(expr.child, *args, **kwargs) + if child is expr.child: + return expr + return type(expr)(child) + + def map_logical_or(self, + expr: p.LogicalOr, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + children = [self.rec(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child + for child, orig_child in zip(children, expr.children)): + return expr + + return type(expr)(tuple(children)) + + def map_logical_and(self, + expr: p.LogicalAnd, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + children = [self.rec(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child + for child, orig_child in zip(children, expr.children)): + return expr + + return type(expr)(tuple(children)) + + def map_comparison(self, + expr: p.Comparison, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: left = self.rec(expr.left, *args, **kwargs) right = self.rec(expr.right, *args, **kwargs) if left is expr.left and right is expr.right: @@ -583,10 +940,16 @@ class IdentityMapper(Mapper): return type(expr)(left, expr.operator, right) - def map_list(self, expr, *args, **kwargs): - return [self.rec(child, *args, **kwargs) for child in expr] + def map_list(self, + expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + + # True fact: lists aren't expressions + return [self.rec(child, *args, **kwargs) for child in expr] # type: ignore[return-value] - def map_tuple(self, expr, *args, **kwargs): + def map_tuple(self, + expr: tuple[ExpressionT, ...], *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr] if all(child is orig_child for child, orig_child in zip(children, expr)): @@ -594,21 +957,28 @@ class IdentityMapper(Mapper): return tuple(children) - def map_numpy_array(self, expr, *args, **kwargs): + def map_numpy_array(self, + expr: np.ndarray, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + import numpy result = numpy.empty(expr.shape, dtype=object) for i in numpy.ndindex(expr.shape): result[i] = self.rec(expr[i], *args, **kwargs) - return result - def map_multivector(self, expr, *args, **kwargs): + # True fact: ndarrays aren't expressions + return result # type: ignore[return-value] + + def map_multivector(self, + expr: MultiVector[ArithmeticExpressionT], + *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: return expr.map(lambda ch: self.rec(ch, *args, **kwargs)) - def map_common_subexpression(self, expr, *args, **kwargs): - from pymbolic.primitives import is_zero + def map_common_subexpression(self, + expr: p.CommonSubexpression, + *args: P.args, **kwargs: P.kwargs) -> ExpressionT: result = self.rec(expr.child, *args, **kwargs) - if is_zero(result): - return 0 if result is expr.child: return expr @@ -618,7 +988,9 @@ class IdentityMapper(Mapper): expr.scope, **expr.get_extra_properties()) - def map_substitution(self, expr, *args, **kwargs): + def map_substitution(self, + expr: p.Substitution, + *args: P.args, **kwargs: P.kwargs) -> ExpressionT: child = self.rec(expr.child, *args, **kwargs) values = tuple([self.rec(v, *args, **kwargs) for v in expr.values]) if child is expr.child and all(val is orig_val @@ -627,36 +999,29 @@ class IdentityMapper(Mapper): return type(expr)(child, expr.variables, values) - def map_derivative(self, expr, *args, **kwargs): + def map_derivative(self, + expr: p.Derivative, + *args: P.args, **kwargs: P.kwargs) -> ExpressionT: child = self.rec(expr.child, *args, **kwargs) if child is expr.child: return expr return type(expr)(child, expr.variables) - def map_slice(self, expr, *args, **kwargs): - children = tuple([ + def map_slice(self, + expr: p.Slice, + *args: P.args, **kwargs: P.kwargs) -> ExpressionT: + children: p.SliceChildrenT = cast(p.SliceChildrenT, tuple([ None if child is None else self.rec(child, *args, **kwargs) for child in expr.children - ]) + ])) if all(child is orig_child for child, orig_child in zip(children, expr.children)): return expr return type(expr)(children) - def map_if_positive(self, expr, *args, **kwargs): - criterion = self.rec(expr.criterion, *args, **kwargs) - then = self.rec(expr.then, *args, **kwargs) - else_ = self.rec(expr.else_, *args, **kwargs) - if criterion is expr.criterion \ - and then is expr.then \ - and else_ is expr.else_: - return expr - - return type(expr)(criterion, then, else_) - - def map_if(self, expr, *args, **kwargs): + def map_if(self, expr: p.If, *args: P.args, **kwargs: P.kwargs) -> ExpressionT: condition = self.rec(expr.condition, *args, **kwargs) then = self.rec(expr.then, *args, **kwargs) else_ = self.rec(expr.else_, *args, **kwargs) @@ -667,7 +1032,7 @@ class IdentityMapper(Mapper): return type(expr)(condition, then, else_) - def map_min(self, expr, *args, **kwargs): + def map_min(self, expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> ExpressionT: children = tuple([ self.rec(child, *args, **kwargs) for child in expr.children ]) @@ -677,14 +1042,22 @@ class IdentityMapper(Mapper): return type(expr)(children) - map_max = map_min + def map_max(self, expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> ExpressionT: + children = tuple([ + self.rec(child, *args, **kwargs) for child in expr.children + ]) + if all(child is orig_child + for child, orig_child in zip(children, expr.children)): + return expr + + return type(expr)(children) - def map_nan(self, expr, *args, **kwargs): + def map_nan(self, expr: p.NaN, *args: P.args, **kwargs: P.kwargs) -> ExpressionT: # Leaf node -- don't recurse return expr -class CachedIdentityMapper(CachedMapper, IdentityMapper): +class CachedIdentityMapper(CachedMapper[ExpressionT, P], IdentityMapper[P]): pass # }}} @@ -692,7 +1065,7 @@ class CachedIdentityMapper(CachedMapper, IdentityMapper): # {{{ walk mapper -class WalkMapper(Mapper): +class WalkMapper(Mapper[None, P]): """A mapper whose default mapper method implementations simply recurse without propagating any result. Also calls :meth:`visit` for each visited subexpression. @@ -709,21 +1082,39 @@ class WalkMapper(Mapper): Is called after a node's children are visited. """ - def map_constant(self, expr, *args, **kwargs): + def map_constant(self, expr: object, *args: P.args, **kwargs: P.kwargs) -> None: + self.visit(expr, *args, **kwargs) + self.post_visit(expr, *args, **kwargs) + + def map_variable(self, expr: p.Variable, *args: P.args, **kwargs: P.kwargs) -> None: + self.visit(expr, *args, **kwargs) + self.post_visit(expr, *args, **kwargs) + + def map_wildcard(self, expr: p.Wildcard, *args: P.args, **kwargs: P.kwargs) -> None: + self.visit(expr, *args, **kwargs) + self.post_visit(expr, *args, **kwargs) + + def map_dot_wildcard(self, + expr: p.DotWildcard, *args: P.args, **kwargs: P.kwargs) -> None: self.visit(expr, *args, **kwargs) self.post_visit(expr, *args, **kwargs) - def map_variable(self, expr, *args, **kwargs): + def map_star_wildcard(self, + expr: p.StarWildcard, *args: P.args, **kwargs: P.kwargs) -> None: self.visit(expr, *args, **kwargs) self.post_visit(expr, *args, **kwargs) - map_wildcard = map_variable - map_dot_wildcard = map_variable - map_star_wildcard = map_variable - map_function_symbol = map_variable - map_nan = map_variable + def map_function_symbol(self, + expr: p.FunctionSymbol, *args: P.args, **kwargs: P.kwargs) -> None: + self.visit(expr, *args, **kwargs) + self.post_visit(expr, *args, **kwargs) - def map_call(self, expr, *args, **kwargs): + def map_nan(self, + expr: p.NaN, *args: P.args, **kwargs: P.kwargs) -> None: + self.visit(expr, *args, **kwargs) + self.post_visit(expr, *args, **kwargs) + + def map_call(self, expr: p.Call, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -733,7 +1124,9 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - def map_call_with_kwargs(self, expr, *args, **kwargs): + def map_call_with_kwargs(self, + expr: p.CallWithKwargs, + *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -746,7 +1139,9 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript(self, + expr: p.Subscript, + *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -755,7 +1150,8 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - def map_lookup(self, expr, *args, **kwargs): + def map_lookup(self, + expr: p.Lookup, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -763,7 +1159,7 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - def map_sum(self, expr, *args, **kwargs): + def map_sum(self, expr: p.Sum, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -772,9 +1168,16 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - map_product = map_sum + def map_product(self, expr: p.Product, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return - def map_quotient(self, expr, *args, **kwargs): + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_quotient(self, expr: p.Quotient, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -783,10 +1186,27 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - map_floor_div = map_quotient - map_remainder = map_quotient + def map_floor_div(self, + expr: p.FloorDiv, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + self.rec(expr.numerator, *args, **kwargs) + self.rec(expr.denominator, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_remainder(self, + expr: p.Remainder, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + self.rec(expr.numerator, *args, **kwargs) + self.rec(expr.denominator, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) - def map_power(self, expr, *args, **kwargs): + def map_power(self, expr: p.Power, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -795,7 +1215,8 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - def map_list(self, expr, *args, **kwargs): + def map_tuple(self, + expr: tuple[ExpressionT, ...], *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -804,9 +1225,8 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - map_tuple = map_list - - def map_numpy_array(self, expr, *args, **kwargs): + def map_numpy_array(self, + expr: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -816,16 +1236,19 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - def map_multivector(self, expr, *args, **kwargs): + def map_multivector(self, + expr: MultiVector[ArithmeticExpressionT], + *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return for _bits, coeff in expr.data.items(): - self.rec(coeff) + self.rec(coeff, *args, **kwargs) self.post_visit(expr, *args, **kwargs) - def map_common_subexpression(self, expr, *args, **kwargs): + def map_common_subexpression(self, + expr: p.CommonSubexpression, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -833,7 +1256,8 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - def map_left_shift(self, expr, *args, **kwargs): + def map_left_shift(self, + expr: p.LeftShift, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -842,9 +1266,18 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - map_right_shift = map_left_shift + def map_right_shift(self, + expr: p.RightShift, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + self.rec(expr.shift, *args, **kwargs) + self.rec(expr.shiftee, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) - def map_bitwise_not(self, expr, *args, **kwargs): + def map_bitwise_not(self, + expr: p.BitwiseNot, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -852,11 +1285,37 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - map_bitwise_or = map_sum - map_bitwise_xor = map_sum - map_bitwise_and = map_sum + def map_bitwise_or(self, + expr: p.BitwiseOr, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_bitwise_xor(self, + expr: p.BitwiseXor, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_bitwise_and(self, + expr: p.BitwiseAnd, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return - def map_comparison(self, expr, *args, **kwargs): + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_comparison(self, expr, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -865,11 +1324,36 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - map_logical_not = map_bitwise_not - map_logical_and = map_sum - map_logical_or = map_sum + def map_logical_not(self, + expr: p.LogicalNot, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return - def map_if(self, expr, *args, **kwargs): + self.rec(expr.child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_logical_or(self, + expr: p.LogicalOr, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_logical_and(self, + expr: p.LogicalAnd, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_if(self, expr, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -879,7 +1363,7 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - def map_if_positive(self, expr, *args, **kwargs): + def map_if_positive(self, expr, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -889,11 +1373,28 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - map_min = map_sum - map_max = map_sum + def map_min(self, + expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return - def map_substitution(self, expr, *args, **kwargs): - if not self.visit(expr): + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_max(self, + expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_substitution(self, expr, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): return self.rec(expr.child, *args, **kwargs) @@ -902,7 +1403,7 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - def map_derivative(self, expr, *args, **kwargs): + def map_derivative(self, expr, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -910,7 +1411,7 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - def map_slice(self, expr, *args, **kwargs): + def map_slice(self, expr, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -923,10 +1424,10 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) - def visit(self, expr, *args, **kwargs): + def visit(self, expr, *args: P.args, **kwargs: P.kwargs) -> bool: return True - def post_visit(self, expr, *args, **kwargs): + def post_visit(self, expr, *args: P.args, **kwargs: P.kwargs) -> None: pass @@ -938,6 +1439,7 @@ class CachedWalkMapper(CachedMapper, WalkMapper): # {{{ callback mapper +# FIXME: Is it worth typing this? class CallbackMapper(Mapper): def __init__(self, function, fallback_mapper): self.function = function @@ -984,7 +1486,7 @@ class CallbackMapper(Mapper): # {{{ caching mixins -class CSECachingMapperMixin(ABC): +class CSECachingMapperMixin(ABC, Generic[ResultT, P]): """A :term:`mix-in` that helps subclassed mappers implement caching for :class:`pymbolic.primitives.CommonSubexpression` @@ -999,23 +1501,28 @@ class CSECachingMapperMixin(ABC): This method deliberately does not support extra arguments in mapper dispatch, to avoid spurious dependencies of the cache on these arguments. """ + _cse_cache_dict: dict[tuple[ExpressionT, P.args, P.kwargs], ResultT] - def map_common_subexpression(self, expr, *args): + def map_common_subexpression(self, + expr: p.CommonSubexpression, + *args: P.args, **kwargs: P.kwargs) -> ResultT: try: ccd = self._cse_cache_dict except AttributeError: ccd = self._cse_cache_dict = {} - key = (expr, *args) + key: tuple[ExpressionT, P.args, P.kwargs] = (expr, args, immutabledict(kwargs)) try: return ccd[key] except KeyError: - result = self.map_common_subexpression_uncached(expr, *args) + result = self.map_common_subexpression_uncached(expr, *args, **kwargs) ccd[key] = result return result @abstractmethod - def map_common_subexpression_uncached(self, expr, *args): + def map_common_subexpression_uncached(self, + expr: p.CommonSubexpression, + *args: P.args, **kwargs: P.kwargs) -> ResultT: pass # }}} diff --git a/pymbolic/mapper/coefficient.py b/pymbolic/mapper/coefficient.py index 99516f0..54e45d1 100644 --- a/pymbolic/mapper/coefficient.py +++ b/pymbolic/mapper/coefficient.py @@ -1,7 +1,5 @@ from __future__ import annotations -from pymbolic.primitives import flattened_product - __copyright__ = "Copyright (C) 2013 Andreas Kloeckner" @@ -25,17 +23,25 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from collections.abc import Collection +from typing import Literal, Mapping, TypeAlias, cast + +import pymbolic.primitives as p from pymbolic.mapper import Mapper +from pymbolic.typing import ArithmeticExpressionT + +CoeffsT: TypeAlias = Mapping[p.AlgebraicLeaf | Literal[1], ArithmeticExpressionT] -class CoefficientCollector(Mapper): - def __init__(self, target_names=None): + +class CoefficientCollector(Mapper[CoeffsT, []]): + def __init__(self, target_names: Collection[str] | None = None) -> None: self.target_names = target_names - def map_sum(self, expr): + def map_sum(self, expr: p.Sum) -> CoeffsT: stride_dicts = [self.rec(ch) for ch in expr.children] - result = {} + result: dict[p.AlgebraicLeaf | Literal[1], ArithmeticExpressionT] = {} for stride_dict in stride_dicts: for var, stride in stride_dict.items(): if var in result: @@ -45,9 +51,7 @@ class CoefficientCollector(Mapper): return result - def map_product(self, expr): - result = {} - + def map_product(self, expr: p.Product) -> CoeffsT: children_coeffs = [self.rec(child) for child in expr.children] idx_of_child_with_vars = None @@ -60,35 +64,33 @@ class CoefficientCollector(Mapper): "nonlinear expression") idx_of_child_with_vars = i - other_coeffs = 1 + other_coeffs: ArithmeticExpressionT = 1 for i, child_coeffs in enumerate(children_coeffs): if i != idx_of_child_with_vars: assert len(child_coeffs) == 1 - other_coeffs *= child_coeffs[1] + other_coeffs *= cast(ArithmeticExpressionT, child_coeffs[1]) if idx_of_child_with_vars is None: return {1: other_coeffs} else: return { - var: flattened_product((other_coeffs, coeff)) + var: p.flattened_product((other_coeffs, coeff)) for var, coeff in children_coeffs[idx_of_child_with_vars].items()} - return result - - def map_quotient(self, expr): + def map_quotient(self, expr: p.Quotient) -> CoeffsT: from pymbolic.primitives import Quotient - d_num = self.rec(expr.numerator) + d_num = dict(self.rec(expr.numerator)) d_den = self.rec(expr.denominator) # d_den should look like {1: k} if len(d_den) > 1 or 1 not in d_den: raise RuntimeError("nonlinear expression") val = d_den[1] for k in d_num.keys(): - d_num[k] = flattened_product((d_num[k], Quotient(1, val))) + d_num[k] = p.flattened_product((d_num[k], Quotient(1, val))) return d_num - def map_power(self, expr): + def map_power(self, expr: p.Power) -> CoeffsT: d_base = self.rec(expr.base) d_exponent = self.rec(expr.exponent) # d_exponent should look like {1: k} @@ -99,11 +101,19 @@ class CoefficientCollector(Mapper): raise RuntimeError("nonlinear expression") return {1: expr} - def map_constant(self, expr): - return {1: expr} + def map_constant(self, expr: object) -> CoeffsT: + assert p.is_arithmetic_expression(expr) + from pymbolic.primitives import is_zero + return {} if is_zero(expr) else {1: expr} - def map_algebraic_leaf(self, expr): + def map_variable(self, expr: p.Variable) -> CoeffsT: if self.target_names is None or expr.name in self.target_names: return {expr: 1} else: return {1: expr} + + def map_algebraic_leaf(self, expr: p.AlgebraicLeaf) -> CoeffsT: + if self.target_names is None: + return {expr: 1} + else: + return {1: expr} diff --git a/pymbolic/mapper/collector.py b/pymbolic/mapper/collector.py index 2799cba..e0c80db 100644 --- a/pymbolic/mapper/collector.py +++ b/pymbolic/mapper/collector.py @@ -26,11 +26,16 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from typing import AbstractSet, Sequence, cast + import pymbolic +import pymbolic.primitives as p from pymbolic.mapper import IdentityMapper +from pymbolic.mapper.dependency import DependenciesT +from pymbolic.typing import ArithmeticExpressionT, ExpressionT -class TermCollector(IdentityMapper): +class TermCollector(IdentityMapper[[]]): """A term collector that assumes that multiplication is commutative. Allows specifying *parameters* (a set of @@ -38,16 +43,19 @@ class TermCollector(IdentityMapper): coefficients and are not used for term collection. """ - def __init__(self, parameters=None): + def __init__(self, parameters: AbstractSet[p.AlgebraicLeaf] | None = None): if parameters is None: parameters = set() self.parameters = parameters - def get_dependencies(self, expr): + def get_dependencies(self, expr: ExpressionT) -> DependenciesT: from pymbolic.mapper.dependency import DependencyMapper return DependencyMapper()(expr) - def split_term(self, mul_term): + def split_term(self, mul_term: ExpressionT) -> tuple[ + AbstractSet[tuple[ArithmeticExpressionT, ArithmeticExpressionT]], + ArithmeticExpressionT + ]: """Returns a pair consisting of: - a frozenset of (base, exponent) pairs - a product of coefficients (i.e. constants and parameters) @@ -58,20 +66,21 @@ class TermCollector(IdentityMapper): """ from pymbolic.primitives import AlgebraicLeaf, Power, Product - def base(term): + def base(term: ExpressionT) -> ArithmeticExpressionT: if isinstance(term, Power): return term.base else: + assert p.is_arithmetic_expression(term) return term - def exponent(term): + def exponent(term: ExpressionT) -> ArithmeticExpressionT: if isinstance(term, Power): return term.exponent else: return 1 if isinstance(mul_term, Product): - terms = mul_term.children + terms: Sequence[ExpressionT] = mul_term.children elif isinstance(mul_term, (Power, AlgebraicLeaf)): terms = [mul_term] elif not bool(self.get_dependencies(mul_term)): @@ -79,7 +88,7 @@ class TermCollector(IdentityMapper): else: raise RuntimeError("split_term expects a multiplicative term") - base2exp = {} + base2exp: dict[ArithmeticExpressionT, ArithmeticExpressionT] = {} for term in terms: mybase = base(term) myexp = exponent(term) @@ -91,20 +100,23 @@ class TermCollector(IdentityMapper): coefficients = [] cleaned_base2exp = {} - for base, exp in base2exp.items(): - term = base**exp + for item_base, item_exp in base2exp.items(): + term = item_base**item_exp if self.get_dependencies(term) <= self.parameters: coefficients.append(term) else: - cleaned_base2exp[base] = exp + cleaned_base2exp[item_base] = item_exp - term = frozenset( + base_exp_set = frozenset( (base, exp) for base, exp in cleaned_base2exp.items()) - return term, self.rec(pymbolic.flattened_product(coefficients)) - - def map_sum(self, mysum): - term2coeff = {} - for child in mysum.children: + return base_exp_set, cast(ArithmeticExpressionT, + self.rec(pymbolic.flattened_product(coefficients))) + + def map_sum(self, expr: p.Sum) -> ExpressionT: + term2coeff: dict[ + AbstractSet[tuple[ArithmeticExpressionT, ArithmeticExpressionT]], + ArithmeticExpressionT] = {} + for child in expr.children: term, coeff = self.split_term(child) term2coeff[term] = term2coeff.get(term, 0) + coeff diff --git a/pymbolic/mapper/constant_folder.py b/pymbolic/mapper/constant_folder.py index 62b3edd..9bf31b8 100644 --- a/pymbolic/mapper/constant_folder.py +++ b/pymbolic/mapper/constant_folder.py @@ -27,13 +27,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from collections.abc import Callable + from pymbolic.mapper import ( CSECachingMapperMixin, IdentityMapper, + Mapper, ) +from pymbolic.primitives import Product, Sum, is_arithmetic_expression +from pymbolic.typing import ArithmeticExpressionT, ExpressionT -class ConstantFoldingMapperBase: +class ConstantFoldingMapperBase(Mapper[ExpressionT, []]): def is_constant(self, expr): from pymbolic.mapper.dependency import DependencyMapper return not bool(DependencyMapper()(expr)) @@ -45,15 +50,27 @@ class ConstantFoldingMapperBase: except ValueError: return None - def fold(self, expr, klass, op, constructor): + def fold(self, + expr: Sum | Product, + op: Callable[ + [ArithmeticExpressionT, ArithmeticExpressionT], + ArithmeticExpressionT], + constructor: Callable[ + [tuple[ArithmeticExpressionT, ...]], + ArithmeticExpressionT], + ) -> ExpressionT: + klass = type(expr) - constants = [] - nonconstants = [] + constants: list[ArithmeticExpressionT] = [] + nonconstants: list[ArithmeticExpressionT] = [] queue = list(expr.children) while queue: - child = self.rec(queue.pop(0)) # pylint:disable=no-member + child = self.rec(queue.pop(0)) + assert is_arithmetic_expression(child) + if isinstance(child, klass): + assert isinstance(child, (Sum, Product)) queue = list(child.children) + queue else: if self.is_constant(child): @@ -73,37 +90,36 @@ class ConstantFoldingMapperBase: else: return constructor(tuple(nonconstants)) - def map_sum(self, expr): + def map_sum(self, expr: Sum) -> ExpressionT: import operator - from pymbolic.primitives import Sum, flattened_sum + from pymbolic.primitives import flattened_sum - return self.fold(expr, Sum, operator.add, flattened_sum) + return self.fold(expr, operator.add, flattened_sum) class CommutativeConstantFoldingMapperBase(ConstantFoldingMapperBase): def map_product(self, expr): import operator - from pymbolic.primitives import Product, flattened_product + from pymbolic.primitives import flattened_product - return self.fold(expr, Product, operator.mul, flattened_product) + return self.fold(expr, operator.mul, flattened_product) class ConstantFoldingMapper( - CSECachingMapperMixin, + CSECachingMapperMixin[ExpressionT, []], ConstantFoldingMapperBase, - IdentityMapper): + IdentityMapper[[]]): map_common_subexpression_uncached = \ IdentityMapper.map_common_subexpression -# Yes, map_product incompatible: missing *args, **kwargs -class CommutativeConstantFoldingMapper( # type: ignore[misc] - CSECachingMapperMixin, +class CommutativeConstantFoldingMapper( + CSECachingMapperMixin[ExpressionT, []], CommutativeConstantFoldingMapperBase, - IdentityMapper): + IdentityMapper[[]]): map_common_subexpression_uncached = \ IdentityMapper.map_common_subexpression diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py index b4473e3..a4e5b0f 100644 --- a/pymbolic/mapper/dependency.py +++ b/pymbolic/mapper/dependency.py @@ -2,6 +2,7 @@ .. autoclass:: DependencyMapper .. autoclass:: CachedDependencyMapper """ + from __future__ import annotations @@ -27,10 +28,21 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from pymbolic.mapper import CachedMapper, Collector, CSECachingMapperMixin +from typing import AbstractSet + +from typing_extensions import TypeAlias + +import pymbolic.primitives as p +from pymbolic.mapper import CachedMapper, Collector, CSECachingMapperMixin, P + +DependenciesT: TypeAlias = AbstractSet[p.AlgebraicLeaf | p.CommonSubexpression] -class DependencyMapper(CSECachingMapperMixin, Collector): + +class DependencyMapper( + CSECachingMapperMixin[DependenciesT, P], + Collector[p.AlgebraicLeaf | p.CommonSubexpression, P], +): """Maps an expression to the :class:`set` of expressions it is based on. The ``include_*`` arguments to the constructor determine which types of objects occur in this output set. @@ -38,12 +50,14 @@ class DependencyMapper(CSECachingMapperMixin, Collector): instances are included. """ - def __init__(self, - include_subscripts=True, - include_lookups=True, - include_calls=True, - include_cses=False, - composite_leaves=None): + def __init__( + self, + include_subscripts: bool = True, + include_lookups: bool = True, + include_calls: bool = True, + include_cses: bool = False, + composite_leaves: bool | None = None, + ): """ :arg composite_leaves: Setting this is equivalent to setting all preceding ``include_*`` flags. @@ -66,68 +80,92 @@ class DependencyMapper(CSECachingMapperMixin, Collector): self.include_cses = include_cses - def map_variable(self, expr, *args, **kwargs): + def map_variable( + self, expr: p.Variable, *args: P.args, **kwargs: P.kwargs + ) -> DependenciesT: return {expr} - def map_call(self, expr, *args, **kwargs): + def map_call( + self, expr: p.Call, *args: P.args, **kwargs: P.kwargs + ) -> DependenciesT: if self.include_calls == "descend_args": - return self.combine( - [self.rec(child, *args, **kwargs) for child in expr.parameters]) + return self.combine([ + self.rec(child, *args, **kwargs) for child in expr.parameters + ]) elif self.include_calls: return {expr} else: return super().map_call(expr, *args, **kwargs) - def map_call_with_kwargs(self, expr, *args, **kwargs): + def map_call_with_kwargs( + self, expr: p.CallWithKwargs, *args: P.args, **kwargs: P.kwargs + ) -> DependenciesT: if self.include_calls == "descend_args": return self.combine( - [self.rec(child, *args, **kwargs) for child in expr.parameters] - + [self.rec(val, *args, **kwargs) for name, val in - expr.kw_parameters.items()] - ) + [self.rec(child, *args, **kwargs) for child in expr.parameters] + + [ + self.rec(val, *args, **kwargs) + for name, val in expr.kw_parameters.items() + ] + ) elif self.include_calls: return {expr} else: return super().map_call_with_kwargs(expr, *args, **kwargs) - def map_lookup(self, expr, *args, **kwargs): + def map_lookup( + self, expr: p.Lookup, *args: P.args, **kwargs: P.kwargs + ) -> DependenciesT: if self.include_lookups: return {expr} else: return super().map_lookup(expr, *args, **kwargs) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript( + self, expr: p.Subscript, *args: P.args, **kwargs: P.kwargs + ) -> DependenciesT: if self.include_subscripts: return {expr} else: return super().map_subscript(expr, *args, **kwargs) - def map_common_subexpression_uncached(self, expr, *args, **kwargs): + def map_common_subexpression_uncached( + self, expr: p.CommonSubexpression, *args: P.args, **kwargs: P.kwargs + ) -> DependenciesT: if self.include_cses: return {expr} else: - return Collector.map_common_subexpression(self, expr, *args, **kwargs) - - def map_slice(self, expr, *args, **kwargs): - return self.combine( - [self.rec(child, *args, **kwargs) for child in expr.children - if child is not None]) - - def map_nan(self, expr, *args, **kwargs): + # FIXME: These look like mypy bugs, revisit + return Collector.map_common_subexpression(self, expr, *args, **kwargs) # type: ignore[return-value, arg-type] + + def map_slice( + self, expr: p.Slice, *args: P.args, **kwargs: P.kwargs + ) -> DependenciesT: + return self.combine([ + self.rec(child, *args, **kwargs) + for child in expr.children + if child is not None + ]) + + def map_nan(self, expr: p.NaN, *args: P.args, **kwargs: P.kwargs) -> DependenciesT: return set() class CachedDependencyMapper(CachedMapper, DependencyMapper): - def __init__(self, - include_subscripts=True, - include_lookups=True, - include_calls=True, - include_cses=False, - composite_leaves=None): + def __init__( + self, + include_subscripts=True, + include_lookups=True, + include_calls=True, + include_cses=False, + composite_leaves=None, + ): CachedMapper.__init__(self) - DependencyMapper.__init__(self, - include_subscripts=include_subscripts, - include_lookups=include_lookups, - include_calls=include_calls, - include_cses=include_cses, - composite_leaves=composite_leaves) + DependencyMapper.__init__( + self, + include_subscripts=include_subscripts, + include_lookups=include_lookups, + include_calls=include_calls, + include_cses=include_cses, + composite_leaves=composite_leaves, + ) diff --git a/pymbolic/mapper/distributor.py b/pymbolic/mapper/distributor.py index faa523e..85d1350 100644 --- a/pymbolic/mapper/distributor.py +++ b/pymbolic/mapper/distributor.py @@ -27,14 +27,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from typing import cast + import pymbolic +import pymbolic.primitives as p from pymbolic.mapper import IdentityMapper from pymbolic.mapper.collector import TermCollector from pymbolic.mapper.constant_folder import CommutativeConstantFoldingMapper -from pymbolic.primitives import Product, Sum, is_zero +from pymbolic.typing import ArithmeticExpressionT, ExpressionT -class DistributeMapper(IdentityMapper): +class DistributeMapper(IdentityMapper[[]]): """Example usage: .. doctest:: @@ -47,7 +50,7 @@ class DistributeMapper(IdentityMapper): 7*x**6 + 21*x**5 + 21*x**2 + 35*x**3 + 1 + 35*x**4 + 7*x + x**7 """ - def __init__(self, collector=None, const_folder=None): + def __init__(self, collector=None, const_folder=None) -> None: if collector is None: collector = TermCollector() if const_folder is None: @@ -61,19 +64,19 @@ class DistributeMapper(IdentityMapper): def map_sum(self, expr): res = IdentityMapper.map_sum(self, expr) - if isinstance(res, Sum): + if isinstance(res, p.Sum): return self.collect(res) else: return res - def map_product(self, expr): + def map_product(self, expr: p.Product) -> ExpressionT: def dist(prod): - if not isinstance(prod, Product): + if not isinstance(prod, p.Product): return prod leading = [] for i in prod.children: - if isinstance(i, Sum): + if isinstance(i, p.Sum): break else: leading.append(i) @@ -84,10 +87,10 @@ class DistributeMapper(IdentityMapper): return result else: sum = prod.children[len(leading)] - assert isinstance(sum, Sum) + assert isinstance(sum, p.Sum) rest = prod.children[len(leading)+1:] if rest: - rest = dist(Product(rest)) + rest = dist(p.Product(rest)) else: rest = 1 @@ -100,7 +103,7 @@ class DistributeMapper(IdentityMapper): return dist(IdentityMapper.map_product(self, expr)) def map_quotient(self, expr): - if is_zero(expr.numerator - 1): + if p.is_zero(expr.numerator - 1): return expr else: # not the smartest thing we can do, but at least *something* @@ -109,18 +112,19 @@ class DistributeMapper(IdentityMapper): self.rec(expr.numerator) ]) - def map_power(self, expr): + def map_power(self, expr: p.Power) -> ExpressionT: from pymbolic.primitives import Sum newbase = self.rec(expr.base) - if isinstance(expr.base, Product): + if isinstance(newbase, p.Product): return self.rec(pymbolic.flattened_product([ - child**expr.exponent for child in newbase + cast(ArithmeticExpressionT, child)**expr.exponent + for child in newbase.children ])) if isinstance(expr.exponent, int): if isinstance(newbase, Sum): - return self.map_product( + return self.rec( pymbolic.flattened_product( expr.exponent*(newbase,))) else: @@ -129,7 +133,7 @@ class DistributeMapper(IdentityMapper): return IdentityMapper.map_power(self, expr) -def distribute(expr, parameters=None, commutative=True): +def distribute(expr: ExpressionT, parameters=None, commutative=True) -> ExpressionT: if parameters is None: parameters = frozenset() if commutative: diff --git a/pymbolic/mapper/evaluator.py b/pymbolic/mapper/evaluator.py index 44e9966..67b7334 100644 --- a/pymbolic/mapper/evaluator.py +++ b/pymbolic/mapper/evaluator.py @@ -33,19 +33,27 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - import operator as op +from collections.abc import Mapping from functools import reduce -from typing import Any +from typing import TYPE_CHECKING, Any +import pymbolic.primitives as p from pymbolic.mapper import CachedMapper, CSECachingMapperMixin, Mapper +from pymbolic.typing import ExpressionT + + +if TYPE_CHECKING: + import numpy as np + + from pymbolic.geometric_algebra import MultiVector class UnknownVariableError(Exception): pass -class EvaluationMapper(Mapper, CSECachingMapperMixin): +class EvaluationMapper(Mapper[Any, []], CSECachingMapperMixin): """Example usage: .. doctest:: @@ -62,7 +70,9 @@ class EvaluationMapper(Mapper, CSECachingMapperMixin): 110 """ - def __init__(self, context=None): + context: Mapping[str, Any] + + def __init__(self, context: Mapping[str, Any] | None = None) -> None: """ :arg context: a mapping from variable names to values """ @@ -70,21 +80,20 @@ class EvaluationMapper(Mapper, CSECachingMapperMixin): context = {} self.context = context - self.common_subexp_cache = {} - def map_constant(self, expr): + def map_constant(self, expr: object) -> Any: return expr - def map_variable(self, expr): + def map_variable(self, expr: p.Variable) -> None: try: return self.context[expr.name] except KeyError: raise UnknownVariableError(expr.name) from None - def map_call(self, expr): + def map_call(self, expr: p.Call) -> Any: return self.rec(expr.function)(*[self.rec(par) for par in expr.parameters]) - def map_call_with_kwargs(self, expr): + def map_call_with_kwargs(self, expr: p.CallWithKwargs) -> Any: args = [self.rec(par) for par in expr.parameters] kwargs = { k: self.rec(v) @@ -92,109 +101,97 @@ class EvaluationMapper(Mapper, CSECachingMapperMixin): return self.rec(expr.function)(*args, **kwargs) - def map_subscript(self, expr): - rec_result = self.rec(expr.aggregate) + def map_subscript(self, expr: p.Subscript) -> Any: + return self.rec(expr.aggregate)[self.rec(expr.index)] - from pymbolic.primitives import Expression - if isinstance(rec_result, Expression): - return rec_result.index(self.rec(expr.index)) - else: - return rec_result[self.rec(expr.index)] - - def map_lookup(self, expr): + def map_lookup(self, expr: p.Lookup) -> Any: return getattr(self.rec(expr.aggregate), expr.name) - def map_sum(self, expr): + def map_sum(self, expr: p.Sum) -> Any: return sum(self.rec(child) for child in expr.children) - def map_product(self, expr): + def map_product(self, expr: p.Product) -> Any: from pytools import product return product(self.rec(child) for child in expr.children) - def map_quotient(self, expr): + def map_quotient(self, expr: p.Quotient) -> Any: return self.rec(expr.numerator) / self.rec(expr.denominator) - def map_floor_div(self, expr): + def map_floor_div(self, expr: p.FloorDiv) -> Any: return self.rec(expr.numerator) // self.rec(expr.denominator) - def map_remainder(self, expr): + def map_remainder(self, expr: p.Remainder) -> Any: return self.rec(expr.numerator) % self.rec(expr.denominator) - def map_power(self, expr): + def map_power(self, expr: p.Power) -> Any: return self.rec(expr.base) ** self.rec(expr.exponent) - def map_left_shift(self, expr): + def map_left_shift(self, expr: p.LeftShift) -> Any: return self.rec(expr.shiftee) << self.rec(expr.shift) - def map_right_shift(self, expr): + def map_right_shift(self, expr: p.RightShift) -> Any: return self.rec(expr.shiftee) >> self.rec(expr.shift) - def map_bitwise_not(self, expr): + def map_bitwise_not(self, expr: p.BitwiseNot) -> Any: # ??? Why, pylint, why ??? # pylint: disable=invalid-unary-operand-type return ~self.rec(expr.child) - def map_bitwise_or(self, expr): + def map_bitwise_or(self, expr: p.BitwiseOr) -> Any: return reduce(op.or_, (self.rec(ch) for ch in expr.children)) - def map_bitwise_xor(self, expr): + def map_bitwise_xor(self, expr: p.BitwiseXor) -> Any: return reduce(op.xor, (self.rec(ch) for ch in expr.children)) - def map_bitwise_and(self, expr): + def map_bitwise_and(self, expr: p.BitwiseAnd) -> Any: return reduce(op.and_, (self.rec(ch) for ch in expr.children)) - def map_logical_not(self, expr): + def map_logical_not(self, expr: p.LogicalNot) -> Any: return not self.rec(expr.child) - def map_logical_or(self, expr): + def map_logical_or(self, expr: p.LogicalOr) -> Any: return any(self.rec(ch) for ch in expr.children) - def map_logical_and(self, expr): + def map_logical_and(self, expr: p.LogicalAnd) -> Any: return all(self.rec(ch) for ch in expr.children) - def map_list(self, expr): + def map_list(self, expr: list[ExpressionT]) -> Any: return [self.rec(child) for child in expr] - def map_numpy_array(self, expr): + def map_numpy_array(self, expr: np.ndarray) -> Any: import numpy result = numpy.empty(expr.shape, dtype=object) for i in numpy.ndindex(expr.shape): result[i] = self.rec(expr[i]) return result - def map_multivector(self, expr, *args): - return expr.map(lambda ch: self.rec(ch, *args)) + def map_multivector(self, expr: MultiVector) -> Any: + return expr.map(lambda ch: self.rec(ch)) - def map_common_subexpression_uncached(self, expr): + def map_common_subexpression_uncached(self, expr: p.CommonSubexpression) -> Any: return self.rec(expr.child) - def map_if_positive(self, expr): - if self.rec(expr.criterion) > 0: + def map_if(self, expr: p.If) -> Any: + if self.rec(expr.condition): return self.rec(expr.then) else: return self.rec(expr.else_) - def map_comparison(self, expr): + def map_comparison(self, expr: p.Comparison) -> Any: import operator return getattr(operator, expr.operator_to_name[expr.operator])( self.rec(expr.left), self.rec(expr.right)) - def map_if(self, expr): - if self.rec(expr.condition): - return self.rec(expr.then) - else: - return self.rec(expr.else_) - - def map_min(self, expr): + def map_min(self, expr: p.Min) -> Any: return min(self.rec(child) for child in expr.children) - def map_max(self, expr): + def map_max(self, expr: p.Max) -> Any: return max(self.rec(child) for child in expr.children) - def map_tuple(self, expr): + def map_tuple(self, expr: tuple[ExpressionT, ...]) -> Any: return tuple([self.rec(child) for child in expr]) - def map_nan(self, expr): + def map_nan(self, expr: p.NaN) -> Any: if expr.data_type is None: from math import nan return nan diff --git a/pymbolic/mapper/flattener.py b/pymbolic/mapper/flattener.py index 4041206..121cf65 100644 --- a/pymbolic/mapper/flattener.py +++ b/pymbolic/mapper/flattener.py @@ -31,10 +31,11 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from typing import cast import pymbolic.primitives as p from pymbolic.mapper import IdentityMapper -from pymbolic.typing import ExpressionT +from pymbolic.typing import ArithmeticExpressionT, ArithmeticOrExpressionT, ExpressionT class FlattenMapper(IdentityMapper[[]]): @@ -50,16 +51,19 @@ class FlattenMapper(IdentityMapper[[]]): """ def map_sum(self, expr: p.Sum) -> ExpressionT: from pymbolic.primitives import flattened_sum - return flattened_sum([self.rec(ch) for ch in expr.children]) + return flattened_sum([ + cast(ArithmeticExpressionT, self.rec(ch)) + for ch in expr.children]) def map_product(self, expr: p.Product) -> ExpressionT: from pymbolic.primitives import flattened_product - return flattened_product([self.rec(ch) for ch in expr.children]) + return flattened_product([ + cast(ArithmeticExpressionT, self.rec(ch)) + for ch in expr.children]) def map_quotient(self, expr: p.Quotient) -> ExpressionT: - r_num = self.rec(expr.numerator) - r_den = self.rec(expr.denominator) - assert p.is_arithmetic_expression(r_den) + r_num = self.rec_arith(expr.numerator) + r_den = self.rec_arith(expr.denominator) if p.is_zero(r_num): return 0 if p.is_zero(r_den - 1): @@ -68,9 +72,8 @@ class FlattenMapper(IdentityMapper[[]]): return expr.__class__(r_num, r_den) def map_floor_div(self, expr: p.FloorDiv) -> ExpressionT: - r_num = self.rec(expr.numerator) - r_den = self.rec(expr.denominator) - assert p.is_arithmetic_expression(r_den) + r_num = self.rec_arith(expr.numerator) + r_den = self.rec_arith(expr.denominator) if p.is_zero(r_num): return 0 if p.is_zero(r_den - 1): @@ -79,8 +82,8 @@ class FlattenMapper(IdentityMapper[[]]): return expr.__class__(r_num, r_den) def map_remainder(self, expr: p.Remainder) -> ExpressionT: - r_num = self.rec(expr.numerator) - r_den = self.rec(expr.denominator) + r_num = self.rec_arith(expr.numerator) + r_den = self.rec_arith(expr.denominator) assert p.is_arithmetic_expression(r_den) if p.is_zero(r_num): return 0 @@ -90,10 +93,8 @@ class FlattenMapper(IdentityMapper[[]]): return expr.__class__(r_num, r_den) def map_power(self, expr: p.Power) -> ExpressionT: - r_base = self.rec(expr.base) - r_exp = self.rec(expr.exponent) - - assert p.is_arithmetic_expression(r_exp) + r_base = self.rec_arith(expr.base) + r_exp = self.rec_arith(expr.exponent) if p.is_zero(r_exp - 1): return r_base @@ -101,5 +102,5 @@ class FlattenMapper(IdentityMapper[[]]): return expr.__class__(r_base, r_exp) -def flatten(expr): - return FlattenMapper()(expr) +def flatten(expr: ArithmeticOrExpressionT) -> ArithmeticOrExpressionT: + return cast(ArithmeticOrExpressionT, FlattenMapper()(expr)) diff --git a/pymbolic/mapper/optimize.py b/pymbolic/mapper/optimize.py index b7c848b..14a2987 100644 --- a/pymbolic/mapper/optimize.py +++ b/pymbolic/mapper/optimize.py @@ -129,15 +129,15 @@ class _RecInliner(ast.NodeTransformer): self.inline_rec = inline_rec self.inline_cache = inline_cache - def visit_Call(self, node): # noqa: N802 - node = self.generic_visit(node) + def visit_Call(self, node: ast.Call) -> ast.AST: # noqa: N802 + node = cast(ast.Call, self.generic_visit(node)) - result_expr = node + result_expr: ast.expr = node if (isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name) and node.func.value.id == "self" - and node.func.attr == "rec"): + and node.func.attr in ["rec", "rec_arith"]): from ast import ( Attribute, @@ -191,7 +191,7 @@ class _RecInliner(ast.NodeTransformer): args=[expr], keywords=[]) cache_key_expr = ast.Tuple([expr_type, expr], ctx=Load()) - nic = Name(id="_NOT_IN_CACHE", ctx=Load()) + nic = Name(id="_NotInCache", ctx=Load()) result_expr = IfExp( test=Compare( diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index 4fc08e1..46ba83f 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -22,11 +22,21 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from collections.abc import Sequence +from typing import TYPE_CHECKING, ClassVar, Concatenate +from warnings import warn -from typing import ClassVar +from typing_extensions import deprecated import pymbolic.primitives as p -from pymbolic.mapper import CachedMapper, Mapper +from pymbolic.mapper import CachedMapper, Mapper, P +from pymbolic.typing import ExpressionT + + +if TYPE_CHECKING: + import numpy as np + + from pymbolic.geometric_algebra import MultiVector __doc__ = """ @@ -81,7 +91,8 @@ PREC_NONE = 0 # {{{ stringifier -class StringifyMapper(Mapper): + +class StringifyMapper(Mapper[str, Concatenate[int, P]]): """A mapper to turn an expression tree into a string. :class:`pymbolic.Expression.__str__` is often implemented using @@ -94,13 +105,25 @@ class StringifyMapper(Mapper): # {{{ replaceable string composition interface - def format(self, s, *args): + def format(self, s: str, *args: object) -> str: return s % args - def join(self, joiner, iterable): - return self.format(joiner.join("%s" for _ in iterable), *iterable) + def join(self, joiner: str, seq: Sequence[ExpressionT]) -> str: + return self.format(joiner.join("%s" for _ in seq), *seq) + + # {{{ deprecated junk + @deprecated("interface not type-safe, use rec_with_parens_around_types") def rec_with_force_parens_around(self, expr, *args, **kwargs): + warn( + "rec_with_force_parens_around is deprecated and will be removed in 2025. " + "Use rec_with_parens_around_types instead. ", + DeprecationWarning, + stacklevel=2, + ) + # Not currently possible to make this type-safe: + # https://peps.python.org/pep-0612/#concatenating-keyword-parameters + force_parens_around = kwargs.pop("force_parens_around", ()) result = self.rec(expr, *args, **kwargs) @@ -110,16 +133,77 @@ class StringifyMapper(Mapper): return result - def join_rec(self, joiner, iterable, prec, *args, **kwargs): - f = joiner.join("%s" for _ in iterable) - return self.format(f, - *[self.rec_with_force_parens_around(i, prec, *args, **kwargs) - for i in iterable]) + def join_rec( + self, + joiner: str, + seq: Sequence[ExpressionT], + prec: int, + *args, + **kwargs, # force_with_parens_around may hide in here + ) -> str: + f = joiner.join("%s" for _ in seq) + + if "force_parens_around" in kwargs: + warn( + "Passing force_parens_around join_rec is deprecated and will be " + "removed in 2025. " + "Use join_rec_with_parens_around_types instead. ", + DeprecationWarning, + stacklevel=2, + ) + # Not currently possible to make this type-safe: + # https://peps.python.org/pep-0612/#concatenating-keyword-parameters + parens_around_types: tuple[type, ...] = kwargs.pop("force_parens_around") + return self.join_rec_with_parens_around_types( + joiner, seq, prec, parens_around_types, *args, **kwargs + ) + + return self.format( + f, + *[self.rec(i, prec, *args, **kwargs) for i in seq], + ) + + # }}} + + def rec_with_parens_around_types( + self, + expr: ExpressionT, + enclosing_prec: int, + parens_around: tuple[type, ...], + *args: P.args, + **kwargs: P.kwargs, + ) -> str: + result = self.rec(expr, enclosing_prec, *args, **kwargs) + + if isinstance(expr, parens_around): + result = f"({result})" - def parenthesize(self, s): + return result + + def join_rec_with_parens_around_types( + self, + joiner: str, + seq: Sequence[ExpressionT], + prec: int, + parens_around_types: tuple[type, ...], + *args: P.args, + **kwargs: P.kwargs, + ) -> str: + f = joiner.join("%s" for _ in seq) + return self.format( + f, + *[ + self.rec_with_parens_around_types( + i, prec, parens_around_types, *args, **kwargs + ) + for i in seq + ], + ) + + def parenthesize(self, s: str) -> str: return f"({s})" - def parenthesize_if_needed(self, s, enclosing_prec, my_prec): + def parenthesize_if_needed(self, s: str, enclosing_prec: int, my_prec: int) -> str: if enclosing_prec > my_prec: return f"({s})" else: @@ -129,210 +213,391 @@ class StringifyMapper(Mapper): # {{{ mappings - def handle_unsupported_expression(self, expr, enclosing_prec, *args, **kwargs): + def handle_unsupported_expression( + self, expr, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: strifier = expr.make_stringifier(self) if isinstance(self, type(strifier)): - raise ValueError( - f"stringifier '{self}' can't handle '{expr.__class__}'") - return strifier( - expr, enclosing_prec, *args, **kwargs) + raise ValueError(f"stringifier '{self}' can't handle '{expr.__class__}'") + return strifier(expr, enclosing_prec, *args, **kwargs) - def map_constant(self, expr, enclosing_prec, *args, **kwargs): + def map_constant( + self, expr: object, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: result = str(expr) - if not (result.startswith("(") and result.endswith(")")) \ - and ("-" in result or "+" in result) \ - and (enclosing_prec > PREC_SUM): + if ( + not (result.startswith("(") and result.endswith(")")) + and ("-" in result or "+" in result) + and (enclosing_prec > PREC_SUM) + ): return self.parenthesize(result) else: return result - def map_variable(self, expr, enclosing_prec, *args, **kwargs): + def map_variable( + self, expr: p.Variable, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return expr.name - def map_wildcard(self, expr, enclosing_prec, *args, **kwargs): + def map_wildcard( + self, expr: p.Wildcard, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return "*" - def map_function_symbol(self, expr, enclosing_prec, *args, **kwargs): + def map_function_symbol( + self, + expr: p.FunctionSymbol, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return expr.__class__.__name__ - def map_call(self, expr, enclosing_prec, *args, **kwargs): - return self.format("%s(%s)", - self.rec(expr.function, PREC_CALL, *args, **kwargs), - self.join_rec(", ", expr.parameters, PREC_NONE, *args, **kwargs)) - - def map_call_with_kwargs(self, expr, enclosing_prec, *args, **kwargs): - args_strings = ( - tuple([ - self.rec(ch, PREC_NONE, *args, **kwargs) - for ch in expr.parameters - ]) - + - tuple([ - "{}={}".format(name, self.rec(ch, PREC_NONE, *args, **kwargs)) - for name, ch in expr.kw_parameters.items() - ]) - ) - return self.format("%s(%s)", - self.rec(expr.function, PREC_CALL, *args, **kwargs), - ", ".join(args_strings)) - - def map_subscript(self, expr, enclosing_prec, *args, **kwargs): + def map_call( + self, expr: p.Call, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: + return self.format( + "%s(%s)", + self.rec(expr.function, PREC_CALL, *args, **kwargs), + self.join_rec(", ", expr.parameters, PREC_NONE, *args, **kwargs), + ) + + def map_call_with_kwargs( + self, + expr: p.CallWithKwargs, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: + args_strings = tuple([ + self.rec(ch, PREC_NONE, *args, **kwargs) for ch in expr.parameters + ]) + tuple([ + "{}={}".format(name, self.rec(ch, PREC_NONE, *args, **kwargs)) + for name, ch in expr.kw_parameters.items() + ]) + return self.format( + "%s(%s)", + self.rec(expr.function, PREC_CALL, *args, **kwargs), + ", ".join(args_strings), + ) + + def map_subscript( + self, expr: p.Subscript, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: if isinstance(expr.index, tuple): index_str = self.join_rec(", ", expr.index, PREC_NONE, *args, **kwargs) else: index_str = self.rec(expr.index, PREC_NONE, *args, **kwargs) return self.parenthesize_if_needed( - self.format("%s[%s]", - self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), - index_str), - enclosing_prec, PREC_CALL) - - def map_lookup(self, expr, enclosing_prec, *args, **kwargs): + self.format( + "%s[%s]", + self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), + index_str, + ), + enclosing_prec, + PREC_CALL, + ) + + def map_lookup( + self, expr: p.Lookup, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("%s.%s", - self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), - expr.name), - enclosing_prec, PREC_CALL) - - def map_sum(self, expr, enclosing_prec, *args, **kwargs): + self.format( + "%s.%s", self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), expr.name + ), + enclosing_prec, + PREC_CALL, + ) + + def map_sum( + self, expr: p.Sum, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.join_rec(" + ", expr.children, PREC_SUM, *args, **kwargs), - enclosing_prec, PREC_SUM) + self.join_rec(" + ", expr.children, PREC_SUM, *args, **kwargs), + enclosing_prec, + PREC_SUM, + ) # {{{ multiplicative operators multiplicative_primitives = (p.Product, p.Quotient, p.FloorDiv, p.Remainder) - def map_product(self, expr, enclosing_prec, *args, **kwargs): - kwargs["force_parens_around"] = (p.Quotient, p.FloorDiv, p.Remainder) + def map_product( + self, expr: p.Product, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.join_rec("*", expr.children, PREC_PRODUCT, *args, **kwargs), - enclosing_prec, PREC_PRODUCT) - - def map_quotient(self, expr, enclosing_prec, *args, **kwargs): - kwargs["force_parens_around"] = self.multiplicative_primitives + self.join_rec_with_parens_around_types( + "*", + expr.children, + PREC_PRODUCT, + (p.Quotient, p.FloorDiv, p.Remainder), + *args, + **kwargs, + ), + enclosing_prec, + PREC_PRODUCT, + ) + + def map_quotient( + self, expr: p.Quotient, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("%s / %s", - # space is necessary--otherwise '/*' becomes - # start-of-comment in C. ('*' from dereference) - self.rec_with_force_parens_around(expr.numerator, PREC_PRODUCT, - *args, **kwargs), - self.rec_with_force_parens_around( - expr.denominator, PREC_PRODUCT, *args, **kwargs)), - enclosing_prec, PREC_PRODUCT) - - def map_floor_div(self, expr, enclosing_prec, *args, **kwargs): - kwargs["force_parens_around"] = self.multiplicative_primitives + self.format( + "%s / %s", + # space is necessary--otherwise '/*' becomes + # start-of-comment in C. ('*' from dereference) + self.rec_with_parens_around_types( + expr.numerator, + PREC_PRODUCT, + self.multiplicative_primitives, + *args, + **kwargs, + ), + self.rec_with_parens_around_types( + expr.denominator, + PREC_PRODUCT, + self.multiplicative_primitives, + *args, + **kwargs, + ), + ), + enclosing_prec, + PREC_PRODUCT, + ) + + def map_floor_div( + self, expr: p.FloorDiv, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("%s // %s", - self.rec_with_force_parens_around( - expr.numerator, PREC_PRODUCT, *args, **kwargs), - self.rec_with_force_parens_around( - expr.denominator, PREC_PRODUCT, *args, **kwargs)), - enclosing_prec, PREC_PRODUCT) - - def map_remainder(self, expr, enclosing_prec, *args, **kwargs): - kwargs["force_parens_around"] = self.multiplicative_primitives + self.format( + "%s // %s", + self.rec_with_parens_around_types( + expr.numerator, + PREC_PRODUCT, + self.multiplicative_primitives, + *args, + **kwargs, + ), + self.rec_with_parens_around_types( + expr.denominator, + PREC_PRODUCT, + self.multiplicative_primitives, + *args, + **kwargs, + ), + ), + enclosing_prec, + PREC_PRODUCT, + ) + + def map_remainder( + self, expr: p.Remainder, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("%s %% %s", - self.rec_with_force_parens_around( - expr.numerator, PREC_PRODUCT, *args, **kwargs), - self.rec_with_force_parens_around( - expr.denominator, PREC_PRODUCT, *args, **kwargs)), - enclosing_prec, PREC_PRODUCT) + self.format( + "%s %% %s", + self.rec_with_parens_around_types( + expr.numerator, + PREC_PRODUCT, + self.multiplicative_primitives, + *args, + **kwargs, + ), + self.rec_with_parens_around_types( + expr.denominator, + PREC_PRODUCT, + self.multiplicative_primitives, + *args, + **kwargs, + ), + ), + enclosing_prec, + PREC_PRODUCT, + ) # }}} - def map_power(self, expr, enclosing_prec, *args, **kwargs): + def map_power( + self, expr: p.Power, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("%s**%s", - self.rec(expr.base, PREC_POWER, *args, **kwargs), - self.rec(expr.exponent, PREC_POWER, *args, **kwargs)), - enclosing_prec, PREC_POWER) - - def map_left_shift(self, expr, enclosing_prec, *args, **kwargs): + self.format( + "%s**%s", + self.rec(expr.base, PREC_POWER, *args, **kwargs), + self.rec(expr.exponent, PREC_POWER, *args, **kwargs), + ), + enclosing_prec, + PREC_POWER, + ) + + def map_left_shift( + self, expr: p.LeftShift, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - # +1 to address - # https://gitlab.tiker.net/inducer/pymbolic/issues/6 - self.format("%s << %s", - self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs), - self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)), - enclosing_prec, PREC_SHIFT) - - def map_right_shift(self, expr, enclosing_prec, *args, **kwargs): + # +1 to address + # https://gitlab.tiker.net/inducer/pymbolic/issues/6 + self.format( + "%s << %s", + self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs), + self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs), + ), + enclosing_prec, + PREC_SHIFT, + ) + + def map_right_shift( + self, + expr: p.RightShift, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - # +1 to address - # https://gitlab.tiker.net/inducer/pymbolic/issues/6 - self.format("%s >> %s", - self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs), - self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)), - enclosing_prec, PREC_SHIFT) - - def map_bitwise_not(self, expr, enclosing_prec, *args, **kwargs): + # +1 to address + # https://gitlab.tiker.net/inducer/pymbolic/issues/6 + self.format( + "%s >> %s", + self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs), + self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs), + ), + enclosing_prec, + PREC_SHIFT, + ) + + def map_bitwise_not( + self, + expr: p.BitwiseNot, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - "~" + self.rec(expr.child, PREC_UNARY, *args, **kwargs), - enclosing_prec, PREC_UNARY) - - def map_bitwise_or(self, expr, enclosing_prec, *args, **kwargs): + "~" + self.rec(expr.child, PREC_UNARY, *args, **kwargs), + enclosing_prec, + PREC_UNARY, + ) + + def map_bitwise_or( + self, expr: p.BitwiseOr, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - " | ", expr.children, PREC_BITWISE_OR, *args, **kwargs), - enclosing_prec, PREC_BITWISE_OR) - - def map_bitwise_xor(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec(" | ", expr.children, PREC_BITWISE_OR, *args, **kwargs), + enclosing_prec, + PREC_BITWISE_OR, + ) + + def map_bitwise_xor( + self, + expr: p.BitwiseXor, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - " ^ ", expr.children, PREC_BITWISE_XOR, *args, **kwargs), - enclosing_prec, PREC_BITWISE_XOR) - - def map_bitwise_and(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec(" ^ ", expr.children, PREC_BITWISE_XOR, *args, **kwargs), + enclosing_prec, + PREC_BITWISE_XOR, + ) + + def map_bitwise_and( + self, + expr: p.BitwiseAnd, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - " & ", expr.children, PREC_BITWISE_AND, *args, **kwargs), - enclosing_prec, PREC_BITWISE_AND) - - def map_comparison(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec(" & ", expr.children, PREC_BITWISE_AND, *args, **kwargs), + enclosing_prec, + PREC_BITWISE_AND, + ) + + def map_comparison( + self, expr: p.Comparison, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("%s %s %s", - self.rec(expr.left, PREC_COMPARISON, *args, **kwargs), - expr.operator, - self.rec(expr.right, PREC_COMPARISON, *args, **kwargs)), - enclosing_prec, PREC_COMPARISON) - - def map_logical_not(self, expr, enclosing_prec, *args, **kwargs): + self.format( + "%s %s %s", + self.rec(expr.left, PREC_COMPARISON, *args, **kwargs), + expr.operator, + self.rec(expr.right, PREC_COMPARISON, *args, **kwargs), + ), + enclosing_prec, + PREC_COMPARISON, + ) + + def map_logical_not( + self, + expr: p.LogicalNot, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - "not " + self.rec(expr.child, PREC_UNARY, *args, **kwargs), - enclosing_prec, PREC_UNARY) - - def map_logical_or(self, expr, enclosing_prec, *args, **kwargs): + "not " + self.rec(expr.child, PREC_UNARY, *args, **kwargs), + enclosing_prec, + PREC_UNARY, + ) + + def map_logical_or( + self, expr: p.LogicalOr, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - " or ", expr.children, PREC_LOGICAL_OR, *args, **kwargs), - enclosing_prec, PREC_LOGICAL_OR) - - def map_logical_and(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec(" or ", expr.children, PREC_LOGICAL_OR, *args, **kwargs), + enclosing_prec, + PREC_LOGICAL_OR, + ) + + def map_logical_and( + self, + expr: p.LogicalAnd, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - " and ", expr.children, PREC_LOGICAL_AND, *args, **kwargs), - enclosing_prec, PREC_LOGICAL_AND) - - def map_list(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec(" and ", expr.children, PREC_LOGICAL_AND, *args, **kwargs), + enclosing_prec, + PREC_LOGICAL_AND, + ) + + def map_list( + self, + expr: list[ExpressionT], + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.format( - "[%s]", self.join_rec(", ", expr, PREC_NONE, *args, **kwargs)) + "[%s]", self.join_rec(", ", expr, PREC_NONE, *args, **kwargs) + ) map_vector = map_list - def map_tuple(self, expr, enclosing_prec, *args, **kwargs): + def map_tuple( + self, + expr: tuple[ExpressionT, ...], + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: el_str = ", ".join( - self.rec(child, PREC_NONE, *args, **kwargs) for child in expr) + self.rec(child, PREC_NONE, *args, **kwargs) for child in expr + ) if len(expr) == 1: el_str += "," return f"({el_str})" - def map_numpy_array(self, expr, enclosing_prec, *args, **kwargs): + def map_numpy_array( + self, + expr: np.ndarray, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: import numpy str_array = numpy.zeros(expr.shape, dtype="object") @@ -345,68 +610,102 @@ class StringifyMapper(Mapper): if len(expr.shape) == 1 and max_length < 15: return "array({})".format(", ".join(str_array)) else: - lines = [" {}: {}\n".format( - ",".join(str(i_i) for i_i in i), val) - for i, val in numpy.ndenumerate(str_array)] + lines = [ + " {}: {}\n".format(",".join(str(i_i) for i_i in i), val) + for i, val in numpy.ndenumerate(str_array) + ] if max_length > 70: - splitter = " " + "-"*75 + "\n" + splitter = " " + "-" * 75 + "\n" return "array(\n{})".format(splitter.join(lines)) else: return "array(\n{})".format("".join(lines)) - def map_multivector(self, expr, enclosing_prec, *args, **kwargs): + def map_multivector( + self, + expr: MultiVector, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return expr.stringify(self.rec, enclosing_prec, *args, **kwargs) - def map_common_subexpression(self, expr, enclosing_prec, *args, **kwargs): + def map_common_subexpression( + self, + expr: p.CommonSubexpression, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: from pymbolic.primitives import CommonSubexpression + if type(expr) is CommonSubexpression: type_name = "CSE" else: type_name = type(expr).__name__ - return self.format("%s(%s)", - type_name, self.rec(expr.child, PREC_NONE, *args, **kwargs)) - - def map_if(self, expr, enclosing_prec, *args, **kwargs): - return self.parenthesize_if_needed( - "{} if {} else {}".format( - self.rec(expr.then, PREC_LOGICAL_OR, *args, **kwargs), - self.rec(expr.condition, PREC_LOGICAL_OR, *args, **kwargs), - self.rec(expr.else_, PREC_LOGICAL_OR, *args, **kwargs)), - enclosing_prec, PREC_IF) + return self.format( + "%s(%s)", type_name, self.rec(expr.child, PREC_NONE, *args, **kwargs) + ) - def map_if_positive(self, expr, enclosing_prec, *args, **kwargs): + def map_if( + self, expr: p.If, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - "{} if {} > 0 else {}".format( - self.rec(expr.then, PREC_LOGICAL_OR, *args, **kwargs), - self.rec(expr.criterion, PREC_LOGICAL_OR, *args, **kwargs), - self.rec(expr.else_, PREC_LOGICAL_OR, *args, **kwargs)), - enclosing_prec, PREC_IF) - - def map_min(self, expr, enclosing_prec, *args, **kwargs): + "{} if {} else {}".format( + self.rec(expr.then, PREC_LOGICAL_OR, *args, **kwargs), + self.rec(expr.condition, PREC_LOGICAL_OR, *args, **kwargs), + self.rec(expr.else_, PREC_LOGICAL_OR, *args, **kwargs), + ), + enclosing_prec, + PREC_IF, + ) + + def map_min( + self, expr: p.Min, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: what = type(expr).__name__.lower() - return self.format("%s(%s)", - what, self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs)) - - map_max = map_min - - def map_derivative(self, expr, enclosing_prec, *args, **kwargs): + return self.format( + "%s(%s)", + what, + self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs), + ) + + def map_max( + self, expr: p.Max, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: + what = type(expr).__name__.lower() + return self.format( + "%s(%s)", + what, + self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs), + ) + + def map_derivative( + self, expr: p.Derivative, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: derivs = " ".join(f"d/d{v}" for v in expr.variables) return "{} {}".format( - derivs, - self.rec(expr.child, PREC_PRODUCT, *args, **kwargs)) - - def map_substitution(self, expr, enclosing_prec, *args, **kwargs): + derivs, self.rec(expr.child, PREC_PRODUCT, *args, **kwargs) + ) + + def map_substitution( + self, + expr: p.Substitution, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: substs = ", ".join( - "{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs)) - for name, val in zip(expr.variables, expr.values)) + "{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs)) + for name, val in zip(expr.variables, expr.values) + ) - return "[%s]{%s}" % ( - self.rec(expr.child, PREC_NONE, *args, **kwargs), - substs) + return "[%s]{%s}" % (self.rec(expr.child, PREC_NONE, *args, **kwargs), substs) - def map_slice(self, expr, enclosing_prec, *args, **kwargs): + def map_slice( + self, expr: p.Slice, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: children = [] for child in expr.children: if child is None: @@ -415,15 +714,17 @@ class StringifyMapper(Mapper): children.append(self.rec(child, PREC_NONE, *args, **kwargs)) return self.parenthesize_if_needed( - self.join(":", children), - enclosing_prec, PREC_NONE) + ":".join(children), enclosing_prec, PREC_NONE + ) - def map_nan(self, expr, enclosing_prec, *args, **kwargs): + def map_nan( + self, expr: p.NaN, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return "NaN" # }}} - def __call__(self, expr, prec=PREC_NONE, *args, **kwargs): + def __call__(self, expr, prec=PREC_NONE, *args: P.args, **kwargs: P.kwargs) -> str: """Return a string corresponding to *expr*. If the enclosing precedence level *prec* is higher than *prec* (see :ref:`prec-constants`), parenthesize the result. @@ -437,15 +738,17 @@ class CachedStringifyMapper(StringifyMapper, CachedMapper): StringifyMapper.__init__(self) CachedMapper.__init__(self) - def __call__(self, expr, prec=PREC_NONE, *args, **kwargs): + def __call__(self, expr, prec=PREC_NONE, *args: P.args, **kwargs: P.kwargs) -> str: return CachedMapper.__call__(expr, prec, *args, **kwargs) + # }}} # {{{ cse-splitting stringifier -class CSESplittingStringifyMapperMixin: + +class CSESplittingStringifyMapperMixin(Mapper[str, Concatenate[int, P]]): """A :term:`mix-in` for subclasses of :class:`StringifyMapper` that collects "variable assignments" for @@ -469,44 +772,45 @@ class CSESplittingStringifyMapperMixin: See :class:`pymbolic.mapper.c_code.CCodeMapper` for an example of the use of this mix-in. """ - def __init__(self): + + cse_to_name: dict[ExpressionT, str] + cse_names: set[str] + cse_name_list: list[tuple[str, str]] + + def __init__(self) -> None: self.cse_to_name = {} self.cse_names = set() self.cse_name_list = [] super().__init__() - def map_common_subexpression(self, expr, enclosing_prec, *args, **kwargs): - # This is here for compatibility, in case the constructor did not get called. - try: - self.cse_to_name # noqa: B018 - except AttributeError: - from warnings import warn - warn("Constructor of CSESplittingStringifyMapperMixin did not get " - "called. This is deprecated and will stop working in 2022.", - DeprecationWarning, stacklevel=2) - - self.cse_to_name = {} - self.cse_names = set() - self.cse_name_list = [] - + def map_common_subexpression( + self, + expr: p.CommonSubexpression, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: try: cse_name = self.cse_to_name[expr.child] except KeyError: str_child = self.rec(expr.child, PREC_NONE, *args, **kwargs) if expr.prefix is not None: + def generate_cse_names(): yield expr.prefix i = 2 while True: yield expr.prefix + f"_{i}" i += 1 + else: + def generate_cse_names(): i = 0 while True: - yield "CSE"+str(i) + yield "CSE" + str(i) i += 1 for cse_name in generate_cse_names(): @@ -519,51 +823,63 @@ class CSESplittingStringifyMapperMixin: return cse_name - def get_cse_strings(self): - return [f"{cse_name} : {cse_str}" - for cse_name, cse_str in - sorted(getattr(self, "cse_name_list", []))] + def get_cse_strings(self) -> list[str]: + return [ + f"{cse_name} : {cse_str}" + for cse_name, cse_str in sorted(getattr(self, "cse_name_list", [])) + ] + # }}} # {{{ sorting stringifier -class SortingStringifyMapper(StringifyMapper): + +class SortingStringifyMapper(StringifyMapper[P]): def __init__(self, reverse=True): super().__init__() self.reverse = reverse - def map_sum(self, expr, enclosing_prec, *args, **kwargs): + def map_sum( + self, expr: p.Sum, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: entries = [self.rec(i, PREC_SUM, *args, **kwargs) for i in expr.children] entries.sort(reverse=self.reverse) - return self.parenthesize_if_needed( - self.join(" + ", entries), - enclosing_prec, PREC_SUM) + return self.parenthesize_if_needed("+".join(entries), enclosing_prec, PREC_SUM) - def map_product(self, expr, enclosing_prec, *args, **kwargs): + def map_product( + self, expr: p.Product, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: entries = [self.rec(i, PREC_PRODUCT, *args, **kwargs) for i in expr.children] entries.sort(reverse=self.reverse) return self.parenthesize_if_needed( - self.join("*", entries), - enclosing_prec, PREC_PRODUCT) + "*".join(entries), enclosing_prec, PREC_PRODUCT + ) + # }}} # {{{ simplifying, sorting stringifier + class SimplifyingSortingStringifyMapper(StringifyMapper): def __init__(self, reverse=True): super().__init__() self.reverse = reverse - def map_sum(self, expr, enclosing_prec, *args, **kwargs): + def map_sum( + self, expr: p.Sum, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: def get_neg_product(expr): from pymbolic.primitives import Product, is_zero - if isinstance(expr, Product) \ - and len(expr.children) and is_zero(expr.children[0]+1): + if ( + isinstance(expr, Product) + and len(expr.children) + and is_zero(expr.children[0] + 1) + ): if len(expr.children) == 2: # only the minus sign and the other child return expr.children[1] @@ -583,29 +899,32 @@ class SimplifyingSortingStringifyMapper(StringifyMapper): positives.append(self.rec(ch, PREC_SUM, *args, **kwargs)) positives.sort(reverse=self.reverse) - positives = " + ".join(positives) + positives_str = " + ".join(positives) negatives.sort(reverse=self.reverse) - negatives = self.join("", - [self.format(" - %s", entry) for entry in negatives]) + negatives_str = "".join(self.format(" - %s", entry) for entry in negatives) - result = positives + negatives + result = positives_str + negatives_str return self.parenthesize_if_needed(result, enclosing_prec, PREC_SUM) - def map_product(self, expr, enclosing_prec, *args, **kwargs): + def map_product( + self, expr: p.Product, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: entries = [] i = 0 from pymbolic.primitives import is_zero while i < len(expr.children): child = expr.children[i] - if False and is_zero(child+1) and i+1 < len(expr.children): + if False and is_zero(child + 1) and i + 1 < len(expr.children): # NOTE: That space needs to be there. # Otherwise two unary minus signs merge into a pre-decrement. entries.append( - self.format( - "- %s", self.rec( - expr.children[i+1], PREC_UNARY, *args, **kwargs))) + self.format( + "- %s", + self.rec(expr.children[i + 1], PREC_UNARY, *args, **kwargs), + ) + ) i += 2 else: entries.append(self.rec(child, PREC_PRODUCT, *args, **kwargs)) @@ -616,127 +935,226 @@ class SimplifyingSortingStringifyMapper(StringifyMapper): return self.parenthesize_if_needed(result, enclosing_prec, PREC_PRODUCT) + # }}} # {{{ latex stringifier -class LaTeXMapper(StringifyMapper): +class LaTeXMapper(StringifyMapper): COMPARISON_OP_TO_LATEX: ClassVar[dict[str, str]] = { "==": r"=", "!=": r"\ne", "<=": r"\le", ">=": r"\ge", - "<": r"<", - ">": r">", - } - - def map_remainder(self, expr, enclosing_prec, *args, **kwargs): - return self.format(r"(%s \bmod %s)", - self.rec(expr.numerator, PREC_PRODUCT, *args, **kwargs), - self.rec(expr.denominator, PREC_POWER, *args, **kwargs)), + "<": r"<", + ">": r">", + } - def map_left_shift(self, expr, enclosing_prec, *args, **kwargs): + def map_remainder( + self, expr: p.Remainder, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: + return self.format( + r"(%s \bmod %s)", + self.rec(expr.numerator, PREC_PRODUCT, *args, **kwargs), + self.rec(expr.denominator, PREC_POWER, *args, **kwargs), + ) + + def map_left_shift( + self, expr: p.LeftShift, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format(r"%s \ll %s", - self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs), - self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)), - enclosing_prec, PREC_SHIFT) - - def map_right_shift(self, expr, enclosing_prec, *args, **kwargs): + self.format( + r"%s \ll %s", + self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs), + self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs), + ), + enclosing_prec, + PREC_SHIFT, + ) + + def map_right_shift( + self, + expr: p.RightShift, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - self.format(r"%s \gg %s", - self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs), - self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)), - enclosing_prec, PREC_SHIFT) - - def map_bitwise_xor(self, expr, enclosing_prec, *args, **kwargs): + self.format( + r"%s \gg %s", + self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs), + self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs), + ), + enclosing_prec, + PREC_SHIFT, + ) + + def map_bitwise_xor( + self, + expr: p.BitwiseXor, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - r" \wedge ", expr.children, PREC_BITWISE_XOR, *args, **kwargs), - enclosing_prec, PREC_BITWISE_XOR) - - def map_product(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec( + r" \wedge ", expr.children, PREC_BITWISE_XOR, *args, **kwargs + ), + enclosing_prec, + PREC_BITWISE_XOR, + ) + + def map_product( + self, expr: p.Product, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.join_rec(" ", expr.children, PREC_PRODUCT, *args, **kwargs), - enclosing_prec, PREC_PRODUCT) - - def map_power(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec(" ", expr.children, PREC_PRODUCT, *args, **kwargs), + enclosing_prec, + PREC_PRODUCT, + ) + + def map_power( + self, expr: p.Power, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("{%s}^{%s}", - self.rec(expr.base, PREC_NONE, *args, **kwargs), - self.rec(expr.exponent, PREC_NONE, *args, **kwargs)), - enclosing_prec, PREC_NONE) - - def map_min(self, expr, enclosing_prec, *args, **kwargs): + self.format( + "{%s}^{%s}", + self.rec(expr.base, PREC_NONE, *args, **kwargs), + self.rec(expr.exponent, PREC_NONE, *args, **kwargs), + ), + enclosing_prec, + PREC_NONE, + ) + + def map_min( + self, expr: p.Min, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: from pytools import is_single_valued + if is_single_valued(expr.children): return self.rec(expr.children[0], enclosing_prec) what = type(expr).__name__.lower() - return self.format(r"\%s(%s)", - what, self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs)) - - def map_max(self, expr, enclosing_prec): - return self.map_min(expr, enclosing_prec) + return self.format( + r"\%s(%s)", + what, + self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs), + ) + + def map_max( + self, expr: p.Max, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: + from pytools import is_single_valued - def map_floor_div(self, expr, enclosing_prec, *args, **kwargs): - return self.format(r"\lfloor {%s} / {%s} \rfloor", - self.rec(expr.numerator, PREC_NONE, *args, **kwargs), - self.rec(expr.denominator, PREC_NONE, *args, **kwargs)) + if is_single_valued(expr.children): + return self.rec(expr.children[0], enclosing_prec) - def map_subscript(self, expr, enclosing_prec, *args, **kwargs): + what = type(expr).__name__.lower() + return self.format( + r"\%s(%s)", + what, + self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs), + ) + + def map_floor_div( + self, expr: p.FloorDiv, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: + return self.format( + r"\lfloor {%s} / {%s} \rfloor", + self.rec(expr.numerator, PREC_NONE, *args, **kwargs), + self.rec(expr.denominator, PREC_NONE, *args, **kwargs), + ) + + def map_subscript( + self, expr: p.Subscript, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: if isinstance(expr.index, tuple): index_str = self.join_rec(", ", expr.index, PREC_NONE, *args, **kwargs) else: index_str = self.rec(expr.index, PREC_NONE, *args, **kwargs) - return self.format("{%s}_{%s}", - self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), - index_str) - - def map_logical_not(self, expr, enclosing_prec, *args, **kwargs): + return self.format( + "{%s}_{%s}", self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), index_str + ) + + def map_logical_not( + self, + expr: p.LogicalNot, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - r"\neg " + self.rec(expr.child, PREC_UNARY, *args, **kwargs), - enclosing_prec, PREC_UNARY) - - def map_logical_or(self, expr, enclosing_prec, *args, **kwargs): + r"\neg " + self.rec(expr.child, PREC_UNARY, *args, **kwargs), + enclosing_prec, + PREC_UNARY, + ) + + def map_logical_or( + self, expr: p.LogicalOr, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - r" \vee ", expr.children, PREC_LOGICAL_OR, *args, **kwargs), - enclosing_prec, PREC_LOGICAL_OR) - - def map_logical_and(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec(r" \vee ", expr.children, PREC_LOGICAL_OR, *args, **kwargs), + enclosing_prec, + PREC_LOGICAL_OR, + ) + + def map_logical_and( + self, + expr: p.LogicalAnd, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - r" \wedge ", expr.children, PREC_LOGICAL_AND, *args, **kwargs), - enclosing_prec, PREC_LOGICAL_AND) - - def map_comparison(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec( + r" \wedge ", expr.children, PREC_LOGICAL_AND, *args, **kwargs + ), + enclosing_prec, + PREC_LOGICAL_AND, + ) + + def map_comparison( + self, expr: p.Comparison, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("%s %s %s", - self.rec(expr.left, PREC_COMPARISON, *args, **kwargs), - self.COMPARISON_OP_TO_LATEX[expr.operator], - self.rec(expr.right, PREC_COMPARISON, *args, **kwargs)), - enclosing_prec, PREC_COMPARISON) - - def map_substitution(self, expr, enclosing_prec, *args, **kwargs): + self.format( + "%s %s %s", + self.rec(expr.left, PREC_COMPARISON, *args, **kwargs), + self.COMPARISON_OP_TO_LATEX[expr.operator], + self.rec(expr.right, PREC_COMPARISON, *args, **kwargs), + ), + enclosing_prec, + PREC_COMPARISON, + ) + + def map_substitution( + self, + expr: p.Substitution, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: substs = ", ".join( - "{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs)) - for name, val in zip(expr.variables, expr.values)) + "{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs)) + for name, val in zip(expr.variables, expr.values) + ) - return self.format(r"[%s]\{%s\}", - self.rec(expr.child, PREC_NONE, *args, **kwargs), - substs) + return self.format( + r"[%s]\{%s\}", self.rec(expr.child, PREC_NONE, *args, **kwargs), substs + ) - def map_derivative(self, expr, enclosing_prec, *args, **kwargs): - derivs = " ".join( - r"\frac{\partial}{\partial %s}" % v - for v in expr.variables) + def map_derivative( + self, expr: p.Derivative, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: + derivs = " ".join(r"\frac{\partial}{\partial %s}" % v for v in expr.variables) + + return self.format( + "%s %s", derivs, self.rec(expr.child, PREC_PRODUCT, *args, **kwargs) + ) - return self.format("%s %s", - derivs, self.rec(expr.child, PREC_PRODUCT, *args, **kwargs)) # }}} diff --git a/pymbolic/mapper/substitutor.py b/pymbolic/mapper/substitutor.py index 8cdb3e1..0418634 100644 --- a/pymbolic/mapper/substitutor.py +++ b/pymbolic/mapper/substitutor.py @@ -4,7 +4,16 @@ .. autofunction:: make_subst_func .. autofunction:: substitute +.. autoclass:: Callable[[AlgebraicLeaf], ExpressionT | None] + +References +---------- + +.. class:: SupportsGetItem + + A protocol with a ``__getitem__`` method. """ + from __future__ import annotations @@ -29,12 +38,19 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from typing import Any, Callable + +from useful_types import SupportsGetItem, SupportsItems from pymbolic.mapper import CachedIdentityMapper, IdentityMapper +from pymbolic.primitives import AlgebraicLeaf +from pymbolic.typing import ExpressionT -class SubstitutionMapper(IdentityMapper): - def __init__(self, subst_func): +class SubstitutionMapper(IdentityMapper[[]]): + def __init__( + self, subst_func: Callable[[AlgebraicLeaf], ExpressionT | None] + ) -> None: self.subst_func = subst_func def map_variable(self, expr): @@ -59,17 +75,26 @@ class SubstitutionMapper(IdentityMapper): return IdentityMapper.map_lookup(self, expr) -class CachedSubstitutionMapper(CachedIdentityMapper, - SubstitutionMapper): - def __init__(self, subst_func): - CachedIdentityMapper.__init__(self) +class CachedSubstitutionMapper(CachedIdentityMapper[[]], SubstitutionMapper): + def __init__( + self, subst_func: Callable[[AlgebraicLeaf], ExpressionT | None] + ) -> None: + # FIXME Mypy says: + # error: Argument 1 to "__init__" of "CachedMapper" has incompatible type + # "CachedSubstitutionMapper"; expected "CachedMapper[ResultT, P]" [arg-type] + # This seems spurious? + CachedIdentityMapper.__init__(self) # type: ignore[arg-type] SubstitutionMapper.__init__(self, subst_func) -def make_subst_func(variable_assignments): +def make_subst_func( + # "Any" here avoids the whole Mapping variance disaster + # e.g. https://github.com/python/typing/issues/445 + variable_assignments: SupportsGetItem[Any, ExpressionT], +) -> Callable[[AlgebraicLeaf], ExpressionT | None]: import pymbolic.primitives as primitives - def subst_func(var): + def subst_func(var: AlgebraicLeaf) -> ExpressionT | None: try: return variable_assignments[var] except KeyError: @@ -84,15 +109,23 @@ def make_subst_func(variable_assignments): return subst_func -def substitute(expression, variable_assignments=None, - mapper_cls=CachedSubstitutionMapper, **kwargs): +def substitute( + expression: ExpressionT, + variable_assignments: SupportsItems[AlgebraicLeaf | str, ExpressionT] | None = None, + mapper_cls=CachedSubstitutionMapper, + **kwargs: ExpressionT, +): """ :arg mapper_cls: A :class:`type` of the substitution mapper whose instance applies the substitution. """ if variable_assignments is None: - variable_assignments = {} - variable_assignments = variable_assignments.copy() - variable_assignments.update(kwargs) + # "Any" here avoids pointless grief about variance + # e.g. https://github.com/python/typing/issues/445 + v_ass_copied: dict[Any, ExpressionT] = {} + else: + v_ass_copied = dict(variable_assignments.items()) + + v_ass_copied.update(kwargs) - return mapper_cls(make_subst_func(variable_assignments))(expression) + return mapper_cls(make_subst_func(v_ass_copied))(expression) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 1b4977a..251adc1 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -38,7 +38,7 @@ from typing import ( from warnings import warn from immutabledict import immutabledict -from typing_extensions import TypeIs, dataclass_transform +from typing_extensions import TypeAlias, TypeIs, dataclass_transform from . import traits from .typing import ArithmeticExpressionT, ExpressionT, NumberT, ScalarT @@ -1650,6 +1650,12 @@ class Derivative(Expression): variables: tuple[str, ...] +SliceChildrenT: TypeAlias = (tuple[()] + | tuple[ExpressionT | None] + | tuple[ExpressionT | None, ExpressionT | None] + | tuple[ExpressionT | None, ExpressionT | None, ExpressionT | None]) + + @expr_dataclass() class Slice(Expression): """A slice expression as in a[1:7]. @@ -1661,10 +1667,7 @@ class Slice(Expression): .. autoproperty:: step """ - children: (tuple[()] - | tuple[ExpressionT] - | tuple[ExpressionT, ExpressionT] - | tuple[ExpressionT, ExpressionT, ExpressionT]) + children: SliceChildrenT def __bool__(self): return True diff --git a/pymbolic/typing.py b/pymbolic/typing.py index fce695d..c1ecaa0 100644 --- a/pymbolic/typing.py +++ b/pymbolic/typing.py @@ -4,18 +4,23 @@ Typing helpers -------------- -.. autodata:: BoolT -.. autodata:: NumberT -.. autodata:: ScalarT -.. autodata:: ArithmeticExpressionT +.. autoclass:: BoolT +.. autoclass:: NumberT +.. autoclass:: ScalarT +.. autoclass:: ArithmeticExpressionT A narrower type alias than :class:`ExpressionT` that is returned by arithmetic operators, to allow continue doing arithmetic with the result of arithmetic. - > +.. autoclass:: ExpressionT -.. autodata:: ExpressionT +.. currentmodule:: pymbolic.typing + +.. autoclass:: ArithmeticOrExpressionT + + A type variable that can be either :data:`ArithmeticExpressionT` + or :data:`ExpressionT`. """ from __future__ import annotations @@ -80,6 +85,11 @@ ArithmeticExpressionT: TypeAlias = Union[NumberT, "Expression"] ExpressionT: TypeAlias = Union[_ScalarOrExpression, Tuple["ExpressionT", ...]] +ArithmeticOrExpressionT = TypeVar( + "ArithmeticOrExpressionT", + ArithmeticExpressionT, + ExpressionT) + T = TypeVar("T") diff --git a/pyproject.toml b/pyproject.toml index 6d18ba8..ea60e24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,8 +32,9 @@ dependencies = [ "astunparse; python_version<='3.9'", "immutabledict", "pytools>=2022.1.14", - # for dataclass_transform, TypeAlias - "typing-extensions>=4", + # for dataclass_transform, TypeAlias, deprecated + "typing-extensions>=4.5", + "useful-types", ] [project.optional-dependencies] -- GitLab