diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index 13b236afd951c2608e22cfdc4735d75a8761629d..73760978c21ee92625bb396edeb078c2be790d4d 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -345,10 +345,13 @@ class InstructionBase(ImmutableRecord): """ raise NotImplementedError - def with_transformed_expressions(self, f): + def with_transformed_expressions(self, f, assignee_f=None): """Return a new copy of *self* where *f* has been applied to every expression occurring in *self*. *args* will be passed as extra arguments (in addition to the expression) to *f*. + + If *assignee_f* is passed, then left-hand sides of assignments are + passed to it. If it is not given, it defaults to the same as *f*. """ raise NotImplementedError @@ -960,12 +963,15 @@ class Assignment(MultiAssignmentBase): def assignee_subscript_deps(self): return (_get_assignee_subscript_deps(self.assignee),) - def with_transformed_expressions(self, f): + def with_transformed_expressions(self, f, assignee_f=None): + if assignee_f is None: + assignee_f = f + return self.copy( - assignee=f(self.assignee), + assignee=assignee_f(self.assignee), expression=f(self.expression), predicates=frozenset( - f(pred, *args) for pred in self.predicates)) + f(pred) for pred in self.predicates)) # }}} @@ -1115,12 +1121,15 @@ class CallInstruction(MultiAssignmentBase): _get_assignee_subscript_deps(a) for a in self.assignees) - def with_transformed_expressions(self, f, *args): + def with_transformed_expressions(self, f, assignee_f=None): + if assignee_f is None: + assignee_f = f + return self.copy( - assignees=f(self.assignees, *args), - expression=f(self.expression, *args), + assignees=assignee_f(self.assignees), + expression=f(self.expression), predicates=frozenset( - f(pred, *args) for pred in self.predicates)) + f(pred) for pred in self.predicates)) # }}} @@ -1316,14 +1325,14 @@ class CInstruction(InstructionBase): _get_assignee_subscript_deps(a) for a in self.assignees) - def with_transformed_expressions(self, f, *args): + def with_transformed_expressions(self, f, assignee_f=None): return self.copy( iname_exprs=[ - (name, f(expr, *args)) + (name, f(expr)) for name, expr in self.iname_exprs], - assignees=[f(a, *args) for a in self.assignees], + assignees=[assignee_f(a) for a in self.assignees], predicates=frozenset( - f(pred, *args) for pred in self.predicates)) + f(pred) for pred in self.predicates)) # }}} @@ -1358,7 +1367,7 @@ class _DataObliviousInstruction(InstructionBase): def assignee_subscript_deps(self): return frozenset() - def with_transformed_expressions(self, f): + def with_transformed_expressions(self, f, assignee_f=None): return self.copy( predicates=frozenset( f(pred) for pred in self.predicates)) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index ccac5e199d2b53e202dd735ffd8dfe20a7dc29a2..d3261b110eef73eb34769e8702af272875613c2c 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -971,7 +971,8 @@ class RuleAwareIdentityMapper(IdentityMapper): # may perform tasks entirely unrelated to subst rules, so # we must map assignees, too. self.map_instruction(kernel, - insn.with_transformed_expressions(self, kernel, insn)) + insn.with_transformed_expressions( + lambda expr: self(expr, kernel, insn))) for insn in kernel.instructions] return kernel.copy(instructions=new_insns) diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index 9f426f76bc6902fd09bd7685c73f187df935be1e..b308836c7727564dbfa9625ad39f378e8034c68c 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -229,7 +229,8 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper): for insn in kernel.instructions: self.replaced_something = False - insn = insn.with_transformed_expressions(self, kernel, insn) + insn = insn.with_transformed_expressions( + lambda expr: self(expr, kernel, insn)) if self.replaced_something: insn = insn.copy(