diff --git a/loopy/symbolic.py b/loopy/symbolic.py index f5cf07b0e1d62212ce36edb48f47eb7de7d31451..11f9d4370d36b2cabccf7690362e7cd2a90ab667 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -69,23 +69,25 @@ import numpy as np # {{{ mappers with support for loopy-specific primitives class IdentityMapperMixin(object): - def map_literal(self, expr, *args): + def map_literal(self, expr, *args, **kwargs): return expr - def map_array_literal(self, expr, *args): - return type(expr)(tuple(self.rec(ch, *args) for ch in expr.children)) + def map_array_literal(self, expr, *args, **kwargs): + return type(expr)(tuple(self.rec(ch, *args, **kwargs) + for ch in expr.children)) - def map_group_hw_index(self, expr, *args): + def map_group_hw_index(self, expr, *args, **kwargs): return expr - def map_local_hw_index(self, expr, *args): + def map_local_hw_index(self, expr, *args, **kwargs): return expr - def map_loopy_function_identifier(self, expr, *args): + def map_loopy_function_identifier(self, expr, *args, **kwargs): return expr - def map_reduction(self, expr, *args): - mapped_inames = [self.rec(p.Variable(iname), *args) for iname in expr.inames] + def map_reduction(self, expr, *args, **kwargs): + mapped_inames = [self.rec(p.Variable(iname), *args, **kwargs) + for iname in expr.inames] new_inames = [] for iname, new_sym_iname in zip(expr.inames, mapped_inames): @@ -98,14 +100,14 @@ class IdentityMapperMixin(object): return Reduction( expr.operation, tuple(new_inames), - self.rec(expr.expr, *args), + self.rec(expr.expr, *args, **kwargs), allow_simultaneous=expr.allow_simultaneous) - def map_tagged_variable(self, expr, *args): + def map_tagged_variable(self, expr, *args, **kwargs): # leaf, doesn't change return expr - def map_type_annotation(self, expr, *args): + def map_type_annotation(self, expr, *args, **kwargs): return type(expr)(expr.type, self.rec(expr.child)) map_type_cast = map_type_annotation @@ -129,37 +131,37 @@ class PartialEvaluationMapper( class WalkMapper(WalkMapperBase): - def map_literal(self, expr, *args): - self.visit(expr) + def map_literal(self, expr, *args, **kwargs): + self.visit(expr, *args, **kwargs) - def map_array_literal(self, expr, *args): - if not self.visit(expr): + def map_array_literal(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return for ch in expr.children: - self.rec(ch, *args) + self.rec(ch, *args, **kwargs) - def map_group_hw_index(self, expr, *args): - self.visit(expr) + def map_group_hw_index(self, expr, *args, **kwargs): + self.visit(expr, *args, **kwargs) - def map_local_hw_index(self, expr, *args): - self.visit(expr) + def map_local_hw_index(self, expr, *args, **kwargs): + self.visit(expr, *args, **kwargs) - def map_reduction(self, expr, *args): - if not self.visit(expr): + def map_reduction(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.expr, *args) + self.rec(expr.expr, *args, **kwargs) - def map_type_cast(self, expr, *args): - if not self.visit(expr): + def map_type_cast(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.child, *args) + self.rec(expr.child, *args, **kwargs) map_tagged_variable = WalkMapperBase.map_variable - def map_loopy_function_identifier(self, expr, *args): - self.visit(expr) + def map_loopy_function_identifier(self, expr, *args, **kwargs): + self.visit(expr, *args, **kwargs) map_linear_subscript = WalkMapperBase.map_subscript @@ -171,8 +173,8 @@ class CallbackMapper(CallbackMapperBase, IdentityMapper): class CombineMapper(CombineMapperBase): - def map_reduction(self, expr): - return self.rec(expr.expr) + def map_reduction(self, expr, *args, **kwargs): + return self.rec(expr.expr, *args, **kwargs) map_linear_subscript = CombineMapperBase.map_subscript @@ -262,32 +264,32 @@ class UnidirectionalUnifier(UnidirectionalUnifierBase): class DependencyMapper(DependencyMapperBase): - def map_group_hw_index(self, expr): + def map_group_hw_index(self, expr, *args, **kwargs): return set() - def map_local_hw_index(self, expr): + def map_local_hw_index(self, expr, *args, **kwargs): return set() - def map_call(self, expr, *args): + def map_call(self, expr, *args, **kwargs): # Loopy does not have first-class functions. Do not descend # into 'function' attribute of Call. return self.combine( - self.rec(child, *args) for child in expr.parameters) + self.rec(child, *args, **kwargs) for child in expr.parameters) - def map_reduction(self, expr): - deps = self.rec(expr.expr) + def map_reduction(self, expr, *args, **kwargs): + deps = self.rec(expr.expr, *args, **kwargs) return deps - set(p.Variable(iname) for iname in expr.inames) - def map_tagged_variable(self, expr): + def map_tagged_variable(self, expr, *args, **kwargs): return set([expr]) - def map_loopy_function_identifier(self, expr): + def map_loopy_function_identifier(self, expr, *args, **kwargs): return set() map_linear_subscript = DependencyMapperBase.map_subscript - def map_type_cast(self, expr): - return self.rec(expr.child) + def map_type_cast(self, expr, *args, **kwargs): + return self.rec(expr.child, *args, **kwargs) class SubstitutionRuleExpander(IdentityMapper):