From 8991e74074a6c02c5445a23b78c95b47ef27dbf2 Mon Sep 17 00:00:00 2001
From: Nicholas Christensen <njchris2@illinois.edu>
Date: Mon, 10 Oct 2022 14:36:07 -0500
Subject: [PATCH] Add args and kwargs to recursive calls that are missing it

---
 pymbolic/mapper/__init__.py   | 16 ++++++++--------
 pymbolic/mapper/dependency.py |  9 +++++----
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py
index dd673c2..700d401 100644
--- a/pymbolic/mapper/__init__.py
+++ b/pymbolic/mapper/__init__.py
@@ -388,11 +388,11 @@ class CombineMapper(RecursiveMapper):
     map_tuple = map_list
 
     def map_numpy_array(self, expr, *args, **kwargs):
-        return self.combine(self.rec(el) for el in expr.flat)
+        return self.combine(self.rec(el, *args, **kwargs) for el in expr.flat)
 
     def map_multivector(self, expr, *args, **kwargs):
         return self.combine(
-                self.rec(coeff)
+                self.rec(coeff, *args, **kwargs)
                 for bits, coeff in expr.data.items())
 
     def map_common_subexpression(self, expr, *args, **kwargs):
@@ -400,15 +400,15 @@ class CombineMapper(RecursiveMapper):
 
     def map_if_positive(self, expr, *args, **kwargs):
         return self.combine([
-            self.rec(expr.criterion),
-            self.rec(expr.then),
-            self.rec(expr.else_)])
+            self.rec(expr.criterion, *args, **kwargs),
+            self.rec(expr.then, *args, **kwargs),
+            self.rec(expr.else_, *args, **kwargs)])
 
     def map_if(self, expr, *args, **kwargs):
         return self.combine([
-            self.rec(expr.condition),
-            self.rec(expr.then),
-            self.rec(expr.else_)])
+            self.rec(expr.condition, *args, **kwargs),
+            self.rec(expr.then, *args, **kwargs),
+            self.rec(expr.else_, *args, **kwargs)])
 
 
 class CachedCombineMapper(CachedMapper, CombineMapper):
diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py
index f784f01..c75128d 100644
--- a/pymbolic/mapper/dependency.py
+++ b/pymbolic/mapper/dependency.py
@@ -70,7 +70,7 @@ class DependencyMapper(CSECachingMapperMixin, Collector):
     def map_call(self, expr, *args, **kwargs):
         if self.include_calls == "descend_args":
             return self.combine(
-                    [self.rec(child) for child in expr.parameters])
+                    [self.rec(child, *args, **kwargs) for child in expr.parameters])
         elif self.include_calls:
             return {expr}
         else:
@@ -79,8 +79,9 @@ class DependencyMapper(CSECachingMapperMixin, Collector):
     def map_call_with_kwargs(self, expr, *args, **kwargs):
         if self.include_calls == "descend_args":
             return self.combine(
-                    [self.rec(child) for child in expr.parameters]
-                    + [self.rec(val) for name, val in expr.kw_parameters.items()]
+                    [self.rec(child, *args, **kwargs) for child in expr.parameters]
+                    + [self.rec(val, *args, **kwargs) for name, val in
+                    expr.kw_parameters.items()]
                     )
         elif self.include_calls:
             return {expr}
@@ -107,7 +108,7 @@ class DependencyMapper(CSECachingMapperMixin, Collector):
 
     def map_slice(self, expr, *args, **kwargs):
         return self.combine(
-                [self.rec(child) for child in expr.children
+                [self.rec(child, *args, **kwargs) for child in expr.children
                     if child is not None])
 
     def map_nan(self, expr, *args, **kwargs):
-- 
GitLab