diff --git a/pymbolic/geometric_algebra/mapper.py b/pymbolic/geometric_algebra/mapper.py index 5be6c9c0ed2823acdf674ec64dc65c5a2f720bdc..618fb192e62080b8961fdd468a125b156c002037 100644 --- a/pymbolic/geometric_algebra/mapper.py +++ b/pymbolic/geometric_algebra/mapper.py @@ -31,7 +31,7 @@ from pymbolic.mapper import ( Collector as CollectorBase, IdentityMapper as IdentityMapperBase, WalkMapper as WalkMapperBase, - CachingMapperMixin + CachedMapper, ) from pymbolic.mapper.constant_folder import ( ConstantFoldingMapper as ConstantFoldingMapperBase) @@ -170,7 +170,11 @@ class Dimensionalizer(EvaluationMapper): # {{{ derivative binder -class DerivativeSourceAndNablaComponentCollector(CachingMapperMixin, Collector): +class DerivativeSourceAndNablaComponentCollector(CachedMapper, Collector): + def __init__(self) -> None: + Collector.__init__(self) + CachedMapper.__init__(self) + def map_nabla(self, expr): raise RuntimeError("DerivativeOccurrenceMapper must be invoked after " "Dimensionalizer--Nabla found, not allowed") diff --git a/pymbolic/mapper/flop_counter.py b/pymbolic/mapper/flop_counter.py index 8dfbd45e31fce4ea5f25a131ac0fd8c6764c70c9..ab0ab965e84c37b31c31204ac8e31d9f23e27762 100644 --- a/pymbolic/mapper/flop_counter.py +++ b/pymbolic/mapper/flop_counter.py @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from pymbolic.mapper import CombineMapper, CachingMapperMixin +from pymbolic.mapper import CombineMapper, CachedMapper class FlopCounterBase(CombineMapper): @@ -55,9 +55,10 @@ class FlopCounterBase(CombineMapper): self.rec(expr.else_)) -class FlopCounter(FlopCounterBase, CachingMapperMixin): - def map_common_subexpression_uncached(self, expr): - return self.rec(expr.child) +class FlopCounter(CachedMapper, FlopCounterBase): + def __init__(self) -> None: + FlopCounterBase.__init__(self) + CachedMapper.__init__(self) class CSEAwareFlopCounter(FlopCounterBase): @@ -70,6 +71,7 @@ class CSEAwareFlopCounter(FlopCounterBase): reuse may take place. """ def __init__(self): + super().__init__() self.cse_seen_set = set() def map_common_subexpression(self, expr):