From e1f8a984df47c0f2f81ab9a0fe301629840d9614 Mon Sep 17 00:00:00 2001 From: Nicholas Christensen <njchris2@illinois.edu> Date: Sat, 1 Oct 2022 19:34:28 -0500 Subject: [PATCH] Allow variable number of args to mapping functions --- pymbolic/mapper/__init__.py | 2 +- pymbolic/mapper/dependency.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index f00f48e..d1dc525 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -434,7 +434,7 @@ class Collector(CombineMapper): from functools import reduce return reduce(operator.or_, values, set()) - def map_constant(self, expr): + def map_constant(self, expr, *args, **kwargs): return set() map_variable = map_constant diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py index b40f8c1..20d502a 100644 --- a/pymbolic/mapper/dependency.py +++ b/pymbolic/mapper/dependency.py @@ -64,19 +64,19 @@ class DependencyMapper(CSECachingMapperMixin, Collector): self.include_cses = include_cses - def map_variable(self, expr): + def map_variable(self, expr, *args, **kwargs): return {expr} - def map_call(self, expr): + def map_call(self, expr, *args, **kwargs): if self.include_calls == "descend_args": return self.combine( [self.rec(child) for child in expr.parameters]) elif self.include_calls: return {expr} else: - return super().map_call(expr) + return super().map_call(expr, *args, **kwargs) - def map_call_with_kwargs(self, expr): + def map_call_with_kwargs(self, expr, *args, **kwargs): if self.include_calls == "descend_args": return self.combine( [self.rec(child) for child in expr.parameters] @@ -87,13 +87,13 @@ class DependencyMapper(CSECachingMapperMixin, Collector): else: return super().map_call_with_kwargs(expr) - def map_lookup(self, expr): + def map_lookup(self, expr, *args, **kwargs): if self.include_lookups: return {expr} else: return super().map_lookup(expr) - def map_subscript(self, expr): + def map_subscript(self, expr, *args, **kwargs): if self.include_subscripts: return {expr} else: @@ -103,14 +103,14 @@ class DependencyMapper(CSECachingMapperMixin, Collector): if self.include_cses: return {expr} else: - return Collector.map_common_subexpression(self, expr) + return Collector.map_common_subexpression(self, expr, *args, **kwargs) - def map_slice(self, expr): + def map_slice(self, expr, *args, **kwargs): return self.combine( [self.rec(child) for child in expr.children if child is not None]) - def map_nan(self, expr): + def map_nan(self, expr, *args, **kwargs): return set() -- GitLab