diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index f00f48e0adcd7975863473e8db749eff0ee855d5..d1dc5253d4c162f8e3f5a529c4fa175a21f746bc 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -434,7 +434,7 @@ class Collector(CombineMapper): from functools import reduce return reduce(operator.or_, values, set()) - def map_constant(self, expr): + def map_constant(self, expr, *args, **kwargs): return set() map_variable = map_constant diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py index b40f8c133340dd51f6da5a5ac8da745c91fc0dad..20d502a7ce03455ba4579ea7590512d0691f1c0d 100644 --- a/pymbolic/mapper/dependency.py +++ b/pymbolic/mapper/dependency.py @@ -64,19 +64,19 @@ class DependencyMapper(CSECachingMapperMixin, Collector): self.include_cses = include_cses - def map_variable(self, expr): + def map_variable(self, expr, *args, **kwargs): return {expr} - def map_call(self, expr): + def map_call(self, expr, *args, **kwargs): if self.include_calls == "descend_args": return self.combine( [self.rec(child) for child in expr.parameters]) elif self.include_calls: return {expr} else: - return super().map_call(expr) + return super().map_call(expr, *args, **kwargs) - def map_call_with_kwargs(self, expr): + def map_call_with_kwargs(self, expr, *args, **kwargs): if self.include_calls == "descend_args": return self.combine( [self.rec(child) for child in expr.parameters] @@ -87,13 +87,13 @@ class DependencyMapper(CSECachingMapperMixin, Collector): else: return super().map_call_with_kwargs(expr) - def map_lookup(self, expr): + def map_lookup(self, expr, *args, **kwargs): if self.include_lookups: return {expr} else: return super().map_lookup(expr) - def map_subscript(self, expr): + def map_subscript(self, expr, *args, **kwargs): if self.include_subscripts: return {expr} else: @@ -103,14 +103,14 @@ class DependencyMapper(CSECachingMapperMixin, Collector): if self.include_cses: return {expr} else: - return Collector.map_common_subexpression(self, expr) + return Collector.map_common_subexpression(self, expr, *args, **kwargs) - def map_slice(self, expr): + def map_slice(self, expr, *args, **kwargs): return self.combine( [self.rec(child) for child in expr.children if child is not None]) - def map_nan(self, expr): + def map_nan(self, expr, *args, **kwargs): return set()