diff --git a/pymbolic/interop/ast.py b/pymbolic/interop/ast.py index d4090e9e714c2fb870afeff3c7a53f34cadcdfd0..76d4f3630fff295d66c13c7d7db0725320a78dcc 100644 --- a/pymbolic/interop/ast.py +++ b/pymbolic/interop/ast.py @@ -462,7 +462,8 @@ def to_evaluatable_python_function(expr: ExpressionT, unparse = ast.unparse - dep_mapper = CachedDependencyMapper(composite_leaves=True) + dep_mapper: CachedDependencyMapper[[]] = ( + CachedDependencyMapper(composite_leaves=True)) deps = sorted({dep.name for dep in dep_mapper(expr)}) ast_func = ast.FunctionDef(name=fn_name, diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 2a32cae4aff1abd4d4c20ee423e5701f89a8525c..abb59a197ce31d061c44a3e1ce5e944d9c590148 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -429,7 +429,7 @@ class CachedMapper(Mapper[ResultT, P]): method_name = getattr(expr, "mapper_method", None) if method_name is not None: method = cast( - Callable[Concatenate[ExpressionT, P], ResultT], + Callable[Concatenate[ExpressionT, P], ResultT] | None, getattr(self, method_name, None) ) if method is not None: diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py index 89b91db2dd1112fa8a4a345dbe2090cb1a09b509..f246a9fdfcbed45fdaebb07c4d8bb10f01fe68a8 100644 --- a/pymbolic/mapper/dependency.py +++ b/pymbolic/mapper/dependency.py @@ -29,7 +29,7 @@ THE SOFTWARE. """ from collections.abc import Set -from typing import TypeAlias +from typing import Literal, TypeAlias import pymbolic.primitives as p from pymbolic.mapper import CachedMapper, Collector, CSECachingMapperMixin, P @@ -53,10 +53,10 @@ class DependencyMapper( self, include_subscripts: bool = True, include_lookups: bool = True, - include_calls: bool = True, + include_calls: bool | Literal["descend_args"] = True, include_cses: bool = False, composite_leaves: bool | None = None, - ): + ) -> None: """ :arg composite_leaves: Setting this is equivalent to setting all preceding ``include_*`` flags. @@ -66,6 +66,7 @@ class DependencyMapper( include_subscripts = False include_lookups = False include_calls = False + if composite_leaves is True: include_subscripts = True include_lookups = True @@ -76,7 +77,6 @@ class DependencyMapper( self.include_subscripts = include_subscripts self.include_lookups = include_lookups self.include_calls = include_calls - self.include_cses = include_cses def map_variable( @@ -150,15 +150,16 @@ class DependencyMapper( return set() -class CachedDependencyMapper(CachedMapper, DependencyMapper): +class CachedDependencyMapper(CachedMapper[DependenciesT, P], + DependencyMapper[P]): def __init__( self, - include_subscripts=True, - include_lookups=True, - include_calls=True, - include_cses=False, - composite_leaves=None, - ): + include_subscripts: bool = True, + include_lookups: bool = True, + include_calls: bool | Literal["descend_args"] = True, + include_cses: bool = False, + composite_leaves: bool | None = None, + ) -> None: CachedMapper.__init__(self) DependencyMapper.__init__( self,