diff --git a/loopy/transform/batch.py b/loopy/transform/batch.py index 7e6b03581e39d03bc06d2f6d37f65a1d4ac6a386..2ad417696fea5d5bd9f200924aa7cf7332964d94 100644 --- a/loopy/transform/batch.py +++ b/loopy/transform/batch.py @@ -64,17 +64,15 @@ class _BatchVariableChanger(RuleAwareIdentityMapper): def needs_batch_subscript(self, name): tv = self.kernel.temporary_variables.get(name) - - if name in self.batch_varying_args: - return True - if not self.sequential: + if self.sequential: + return name in self.batch_varying_args + else: if tv is None: return False if not temp_needs_batching_if_not_sequential(tv, self.batch_varying_args): return False - - return True + return True def map_subscript(self, expr, expn_state): if not self.needs_batch_subscript(expr.aggregate.name):