diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py index 40eb651d9de14078752af3ad9d49506486ad5502..08a40c6338235eb1d465cb7007018c7110376641 100644 --- a/pymbolic/mapper/dependency.py +++ b/pymbolic/mapper/dependency.py @@ -83,6 +83,17 @@ class DependencyMapper(CSECachingMapperMixin, CombineMapper): else: return CombineMapper.map_call(self, expr) + def map_call_with_kwargs(self, expr): + if self.include_calls == "descend_args": + return self.combine( + [self.rec(child) for child in expr.parameters] + + [self.rec(val) for name, val in expr.kw_parameters.items()] + ) + elif self.include_calls: + return set([expr]) + else: + return CombineMapper.map_call(self, expr) + def map_lookup(self, expr): if self.include_lookups: return set([expr]) diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 48c4740d4b14bd89dfbf3d6a96bc5807b8020801..a8951070a395b4432f5596251bf940584a4a7287 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -255,6 +255,16 @@ def test_mappers(): DependencyMapper()(expr) +def test_func_dep_consistency(): + from pymbolic import var + from pymbolic.mapper.dependency import DependencyMapper + f = var('f') + x = var('x') + dep_map = DependencyMapper(include_calls="descend_args") + assert dep_map(f(x)) == set([x]) + assert dep_map(f(x=x)) == set([x]) + + # {{{ geometric algebra @pytest.mark.parametrize("dims", [2, 3, 4, 5])