Skip to content
Snippets Groups Projects
Commit c81f687f authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Merge branch 'symbolicfix' into 'master'

passing parameters in MapperClasses fix. Contraction with pymbolic

See merge request inducer/loopy!335
parents c6976a42 ed677a47
No related branches found
No related tags found
No related merge requests found
......@@ -69,23 +69,25 @@ import numpy as np
# {{{ mappers with support for loopy-specific primitives
class IdentityMapperMixin(object):
def map_literal(self, expr, *args):
def map_literal(self, expr, *args, **kwargs):
return expr
def map_array_literal(self, expr, *args):
return type(expr)(tuple(self.rec(ch, *args) for ch in expr.children))
def map_array_literal(self, expr, *args, **kwargs):
return type(expr)(tuple(self.rec(ch, *args, **kwargs)
for ch in expr.children))
def map_group_hw_index(self, expr, *args):
def map_group_hw_index(self, expr, *args, **kwargs):
return expr
def map_local_hw_index(self, expr, *args):
def map_local_hw_index(self, expr, *args, **kwargs):
return expr
def map_loopy_function_identifier(self, expr, *args):
def map_loopy_function_identifier(self, expr, *args, **kwargs):
return expr
def map_reduction(self, expr, *args):
mapped_inames = [self.rec(p.Variable(iname), *args) for iname in expr.inames]
def map_reduction(self, expr, *args, **kwargs):
mapped_inames = [self.rec(p.Variable(iname), *args, **kwargs)
for iname in expr.inames]
new_inames = []
for iname, new_sym_iname in zip(expr.inames, mapped_inames):
......@@ -98,14 +100,14 @@ class IdentityMapperMixin(object):
return Reduction(
expr.operation, tuple(new_inames),
self.rec(expr.expr, *args),
self.rec(expr.expr, *args, **kwargs),
allow_simultaneous=expr.allow_simultaneous)
def map_tagged_variable(self, expr, *args):
def map_tagged_variable(self, expr, *args, **kwargs):
# leaf, doesn't change
return expr
def map_type_annotation(self, expr, *args):
def map_type_annotation(self, expr, *args, **kwargs):
return type(expr)(expr.type, self.rec(expr.child))
map_type_cast = map_type_annotation
......@@ -129,37 +131,37 @@ class PartialEvaluationMapper(
class WalkMapper(WalkMapperBase):
def map_literal(self, expr, *args):
self.visit(expr)
def map_literal(self, expr, *args, **kwargs):
self.visit(expr, *args, **kwargs)
def map_array_literal(self, expr, *args):
if not self.visit(expr):
def map_array_literal(self, expr, *args, **kwargs):
if not self.visit(expr, *args, **kwargs):
return
for ch in expr.children:
self.rec(ch, *args)
self.rec(ch, *args, **kwargs)
def map_group_hw_index(self, expr, *args):
self.visit(expr)
def map_group_hw_index(self, expr, *args, **kwargs):
self.visit(expr, *args, **kwargs)
def map_local_hw_index(self, expr, *args):
self.visit(expr)
def map_local_hw_index(self, expr, *args, **kwargs):
self.visit(expr, *args, **kwargs)
def map_reduction(self, expr, *args):
if not self.visit(expr):
def map_reduction(self, expr, *args, **kwargs):
if not self.visit(expr, *args, **kwargs):
return
self.rec(expr.expr, *args)
self.rec(expr.expr, *args, **kwargs)
def map_type_cast(self, expr, *args):
if not self.visit(expr):
def map_type_cast(self, expr, *args, **kwargs):
if not self.visit(expr, *args, **kwargs):
return
self.rec(expr.child, *args)
self.rec(expr.child, *args, **kwargs)
map_tagged_variable = WalkMapperBase.map_variable
def map_loopy_function_identifier(self, expr, *args):
self.visit(expr)
def map_loopy_function_identifier(self, expr, *args, **kwargs):
self.visit(expr, *args, **kwargs)
map_linear_subscript = WalkMapperBase.map_subscript
......@@ -171,8 +173,8 @@ class CallbackMapper(CallbackMapperBase, IdentityMapper):
class CombineMapper(CombineMapperBase):
def map_reduction(self, expr):
return self.rec(expr.expr)
def map_reduction(self, expr, *args, **kwargs):
return self.rec(expr.expr, *args, **kwargs)
map_linear_subscript = CombineMapperBase.map_subscript
......@@ -262,32 +264,32 @@ class UnidirectionalUnifier(UnidirectionalUnifierBase):
class DependencyMapper(DependencyMapperBase):
def map_group_hw_index(self, expr):
def map_group_hw_index(self, expr, *args, **kwargs):
return set()
def map_local_hw_index(self, expr):
def map_local_hw_index(self, expr, *args, **kwargs):
return set()
def map_call(self, expr, *args):
def map_call(self, expr, *args, **kwargs):
# Loopy does not have first-class functions. Do not descend
# into 'function' attribute of Call.
return self.combine(
self.rec(child, *args) for child in expr.parameters)
self.rec(child, *args, **kwargs) for child in expr.parameters)
def map_reduction(self, expr):
deps = self.rec(expr.expr)
def map_reduction(self, expr, *args, **kwargs):
deps = self.rec(expr.expr, *args, **kwargs)
return deps - set(p.Variable(iname) for iname in expr.inames)
def map_tagged_variable(self, expr):
def map_tagged_variable(self, expr, *args, **kwargs):
return set([expr])
def map_loopy_function_identifier(self, expr):
def map_loopy_function_identifier(self, expr, *args, **kwargs):
return set()
map_linear_subscript = DependencyMapperBase.map_subscript
def map_type_cast(self, expr):
return self.rec(expr.child)
def map_type_cast(self, expr, *args, **kwargs):
return self.rec(expr.child, *args, **kwargs)
class SubstitutionRuleExpander(IdentityMapper):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment