From 8991e74074a6c02c5445a23b78c95b47ef27dbf2 Mon Sep 17 00:00:00 2001 From: Nicholas Christensen <njchris2@illinois.edu> Date: Mon, 10 Oct 2022 14:36:07 -0500 Subject: [PATCH] Add args and kwargs to recursive calls that are missing it --- pymbolic/mapper/__init__.py | 16 ++++++++-------- pymbolic/mapper/dependency.py | 9 +++++---- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index dd673c2..700d401 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -388,11 +388,11 @@ class CombineMapper(RecursiveMapper): map_tuple = map_list def map_numpy_array(self, expr, *args, **kwargs): - return self.combine(self.rec(el) for el in expr.flat) + return self.combine(self.rec(el, *args, **kwargs) for el in expr.flat) def map_multivector(self, expr, *args, **kwargs): return self.combine( - self.rec(coeff) + self.rec(coeff, *args, **kwargs) for bits, coeff in expr.data.items()) def map_common_subexpression(self, expr, *args, **kwargs): @@ -400,15 +400,15 @@ class CombineMapper(RecursiveMapper): def map_if_positive(self, expr, *args, **kwargs): return self.combine([ - self.rec(expr.criterion), - self.rec(expr.then), - self.rec(expr.else_)]) + self.rec(expr.criterion, *args, **kwargs), + self.rec(expr.then, *args, **kwargs), + self.rec(expr.else_, *args, **kwargs)]) def map_if(self, expr, *args, **kwargs): return self.combine([ - self.rec(expr.condition), - self.rec(expr.then), - self.rec(expr.else_)]) + self.rec(expr.condition, *args, **kwargs), + self.rec(expr.then, *args, **kwargs), + self.rec(expr.else_, *args, **kwargs)]) class CachedCombineMapper(CachedMapper, CombineMapper): diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py index f784f01..c75128d 100644 --- a/pymbolic/mapper/dependency.py +++ b/pymbolic/mapper/dependency.py @@ -70,7 +70,7 @@ class DependencyMapper(CSECachingMapperMixin, Collector): def map_call(self, expr, *args, **kwargs): if self.include_calls == "descend_args": return self.combine( - [self.rec(child) for child in expr.parameters]) + [self.rec(child, *args, **kwargs) for child in expr.parameters]) elif self.include_calls: return {expr} else: @@ -79,8 +79,9 @@ class DependencyMapper(CSECachingMapperMixin, Collector): def map_call_with_kwargs(self, expr, *args, **kwargs): if self.include_calls == "descend_args": return self.combine( - [self.rec(child) for child in expr.parameters] - + [self.rec(val) for name, val in expr.kw_parameters.items()] + [self.rec(child, *args, **kwargs) for child in expr.parameters] + + [self.rec(val, *args, **kwargs) for name, val in + expr.kw_parameters.items()] ) elif self.include_calls: return {expr} @@ -107,7 +108,7 @@ class DependencyMapper(CSECachingMapperMixin, Collector): def map_slice(self, expr, *args, **kwargs): return self.combine( - [self.rec(child) for child in expr.children + [self.rec(child, *args, **kwargs) for child in expr.children if child is not None]) def map_nan(self, expr, *args, **kwargs): -- GitLab