diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index dd673c2ead541411523eb8988df8d36c910e1ed1..700d4017940e72276d602d134f8e32ca86f41b1a 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -388,11 +388,11 @@ class CombineMapper(RecursiveMapper): map_tuple = map_list def map_numpy_array(self, expr, *args, **kwargs): - return self.combine(self.rec(el) for el in expr.flat) + return self.combine(self.rec(el, *args, **kwargs) for el in expr.flat) def map_multivector(self, expr, *args, **kwargs): return self.combine( - self.rec(coeff) + self.rec(coeff, *args, **kwargs) for bits, coeff in expr.data.items()) def map_common_subexpression(self, expr, *args, **kwargs): @@ -400,15 +400,15 @@ class CombineMapper(RecursiveMapper): def map_if_positive(self, expr, *args, **kwargs): return self.combine([ - self.rec(expr.criterion), - self.rec(expr.then), - self.rec(expr.else_)]) + self.rec(expr.criterion, *args, **kwargs), + self.rec(expr.then, *args, **kwargs), + self.rec(expr.else_, *args, **kwargs)]) def map_if(self, expr, *args, **kwargs): return self.combine([ - self.rec(expr.condition), - self.rec(expr.then), - self.rec(expr.else_)]) + self.rec(expr.condition, *args, **kwargs), + self.rec(expr.then, *args, **kwargs), + self.rec(expr.else_, *args, **kwargs)]) class CachedCombineMapper(CachedMapper, CombineMapper): diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py index f784f017426e44923e10ee6cdbb5e4f2cae47dc4..c75128d3ff917067e1d432037a6074f2127b39f6 100644 --- a/pymbolic/mapper/dependency.py +++ b/pymbolic/mapper/dependency.py @@ -70,7 +70,7 @@ class DependencyMapper(CSECachingMapperMixin, Collector): def map_call(self, expr, *args, **kwargs): if self.include_calls == "descend_args": return self.combine( - [self.rec(child) for child in expr.parameters]) + [self.rec(child, *args, **kwargs) for child in expr.parameters]) elif self.include_calls: return {expr} else: @@ -79,8 +79,9 @@ class DependencyMapper(CSECachingMapperMixin, Collector): def map_call_with_kwargs(self, expr, *args, **kwargs): if self.include_calls == "descend_args": return self.combine( - [self.rec(child) for child in expr.parameters] - + [self.rec(val) for name, val in expr.kw_parameters.items()] + [self.rec(child, *args, **kwargs) for child in expr.parameters] + + [self.rec(val, *args, **kwargs) for name, val in + expr.kw_parameters.items()] ) elif self.include_calls: return {expr} @@ -107,7 +108,7 @@ class DependencyMapper(CSECachingMapperMixin, Collector): def map_slice(self, expr, *args, **kwargs): return self.combine( - [self.rec(child) for child in expr.children + [self.rec(child, *args, **kwargs) for child in expr.children if child is not None]) def map_nan(self, expr, *args, **kwargs):