Skip to content
Snippets Groups Projects
Commit e1f8a984 authored by Nicholas Christensen's avatar Nicholas Christensen Committed by Andreas Klöckner
Browse files

Allow variable number of args to mapping functions

parent c58dfaa7
No related branches found
No related tags found
No related merge requests found
...@@ -434,7 +434,7 @@ class Collector(CombineMapper): ...@@ -434,7 +434,7 @@ class Collector(CombineMapper):
from functools import reduce from functools import reduce
return reduce(operator.or_, values, set()) return reduce(operator.or_, values, set())
def map_constant(self, expr): def map_constant(self, expr, *args, **kwargs):
return set() return set()
map_variable = map_constant map_variable = map_constant
......
...@@ -64,19 +64,19 @@ class DependencyMapper(CSECachingMapperMixin, Collector): ...@@ -64,19 +64,19 @@ class DependencyMapper(CSECachingMapperMixin, Collector):
self.include_cses = include_cses self.include_cses = include_cses
def map_variable(self, expr): def map_variable(self, expr, *args, **kwargs):
return {expr} return {expr}
def map_call(self, expr): def map_call(self, expr, *args, **kwargs):
if self.include_calls == "descend_args": if self.include_calls == "descend_args":
return self.combine( return self.combine(
[self.rec(child) for child in expr.parameters]) [self.rec(child) for child in expr.parameters])
elif self.include_calls: elif self.include_calls:
return {expr} return {expr}
else: else:
return super().map_call(expr) return super().map_call(expr, *args, **kwargs)
def map_call_with_kwargs(self, expr): def map_call_with_kwargs(self, expr, *args, **kwargs):
if self.include_calls == "descend_args": if self.include_calls == "descend_args":
return self.combine( return self.combine(
[self.rec(child) for child in expr.parameters] [self.rec(child) for child in expr.parameters]
...@@ -87,13 +87,13 @@ class DependencyMapper(CSECachingMapperMixin, Collector): ...@@ -87,13 +87,13 @@ class DependencyMapper(CSECachingMapperMixin, Collector):
else: else:
return super().map_call_with_kwargs(expr) return super().map_call_with_kwargs(expr)
def map_lookup(self, expr): def map_lookup(self, expr, *args, **kwargs):
if self.include_lookups: if self.include_lookups:
return {expr} return {expr}
else: else:
return super().map_lookup(expr) return super().map_lookup(expr)
def map_subscript(self, expr): def map_subscript(self, expr, *args, **kwargs):
if self.include_subscripts: if self.include_subscripts:
return {expr} return {expr}
else: else:
...@@ -103,14 +103,14 @@ class DependencyMapper(CSECachingMapperMixin, Collector): ...@@ -103,14 +103,14 @@ class DependencyMapper(CSECachingMapperMixin, Collector):
if self.include_cses: if self.include_cses:
return {expr} return {expr}
else: else:
return Collector.map_common_subexpression(self, expr) return Collector.map_common_subexpression(self, expr, *args, **kwargs)
def map_slice(self, expr): def map_slice(self, expr, *args, **kwargs):
return self.combine( return self.combine(
[self.rec(child) for child in expr.children [self.rec(child) for child in expr.children
if child is not None]) if child is not None])
def map_nan(self, expr): def map_nan(self, expr, *args, **kwargs):
return set() return set()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment