From e1f8a984df47c0f2f81ab9a0fe301629840d9614 Mon Sep 17 00:00:00 2001
From: Nicholas Christensen <njchris2@illinois.edu>
Date: Sat, 1 Oct 2022 19:34:28 -0500
Subject: [PATCH] Allow variable number of args to mapping functions

---
 pymbolic/mapper/__init__.py   |  2 +-
 pymbolic/mapper/dependency.py | 18 +++++++++---------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py
index f00f48e..d1dc525 100644
--- a/pymbolic/mapper/__init__.py
+++ b/pymbolic/mapper/__init__.py
@@ -434,7 +434,7 @@ class Collector(CombineMapper):
         from functools import reduce
         return reduce(operator.or_, values, set())
 
-    def map_constant(self, expr):
+    def map_constant(self, expr, *args, **kwargs):
         return set()
 
     map_variable = map_constant
diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py
index b40f8c1..20d502a 100644
--- a/pymbolic/mapper/dependency.py
+++ b/pymbolic/mapper/dependency.py
@@ -64,19 +64,19 @@ class DependencyMapper(CSECachingMapperMixin, Collector):
 
         self.include_cses = include_cses
 
-    def map_variable(self, expr):
+    def map_variable(self, expr, *args, **kwargs):
         return {expr}
 
-    def map_call(self, expr):
+    def map_call(self, expr, *args, **kwargs):
         if self.include_calls == "descend_args":
             return self.combine(
                     [self.rec(child) for child in expr.parameters])
         elif self.include_calls:
             return {expr}
         else:
-            return super().map_call(expr)
+            return super().map_call(expr, *args, **kwargs)
 
-    def map_call_with_kwargs(self, expr):
+    def map_call_with_kwargs(self, expr, *args, **kwargs):
         if self.include_calls == "descend_args":
             return self.combine(
                     [self.rec(child) for child in expr.parameters]
@@ -87,13 +87,13 @@ class DependencyMapper(CSECachingMapperMixin, Collector):
         else:
             return super().map_call_with_kwargs(expr)
 
-    def map_lookup(self, expr):
+    def map_lookup(self, expr, *args, **kwargs):
         if self.include_lookups:
             return {expr}
         else:
             return super().map_lookup(expr)
 
-    def map_subscript(self, expr):
+    def map_subscript(self, expr, *args, **kwargs):
         if self.include_subscripts:
             return {expr}
         else:
@@ -103,14 +103,14 @@ class DependencyMapper(CSECachingMapperMixin, Collector):
         if self.include_cses:
             return {expr}
         else:
-            return Collector.map_common_subexpression(self, expr)
+            return Collector.map_common_subexpression(self, expr, *args, **kwargs)
 
-    def map_slice(self, expr):
+    def map_slice(self, expr, *args, **kwargs):
         return self.combine(
                 [self.rec(child) for child in expr.children
                     if child is not None])
 
-    def map_nan(self, expr):
+    def map_nan(self, expr, *args, **kwargs):
         return set()
 
 
-- 
GitLab