diff --git a/loopy/transform/batch.py b/loopy/transform/batch.py index 1dc54f94b3f59af0ebd5a24ae63d06146ad464e2..967e14de692ee96b0c01d2ea5bcf8b411890038b 100644 --- a/loopy/transform/batch.py +++ b/loopy/transform/batch.py @@ -38,16 +38,18 @@ __doc__ = """ class _BatchVariableChanger(RuleAwareIdentityMapper): def __init__(self, rule_mapping_context, kernel, batch_varying_args, - batch_iname_expr): + batch_iname_expr, sequential): super(_BatchVariableChanger, self).__init__(rule_mapping_context) self.kernel = kernel self.batch_varying_args = batch_varying_args self.batch_iname_expr = batch_iname_expr + self.sequential = sequential def needs_batch_subscript(self, name): return ( - name in self.kernel.temporary_variables + (not self.sequential + and name in self.kernel.temporary_variables) or name in self.batch_varying_args) @@ -68,14 +70,18 @@ class _BatchVariableChanger(RuleAwareIdentityMapper): return expr.aggregate[self.batch_iname_expr] -def to_batched(knl, nbatches, batch_varying_args, batch_iname_prefix="ibatch"): +def to_batched(knl, nbatches, batch_varying_args, batch_iname_prefix="ibatch", + sequential=False): """Takes in a kernel that carries out an operation and returns a kernel that carries out a batch of these operations. :arg nbatches: the number of batches. May be a constant non-negative integer or a string, which will be added as an integer argument. - :arg batch_varying_args: a list of argument names that depend vary per-batch. + :arg batch_varying_args: a list of argument names that vary per-batch. Each such variable will have a batch index added. + :arg sequential: A :class:`bool`. If *True*, do not duplicate + temporary variables for each batch. This automatically tags the batch + iname for sequential execution. """ from pymbolic import var @@ -114,26 +120,32 @@ def to_batched(knl, nbatches, batch_varying_args, batch_iname_prefix="ibatch"): new_args.append(arg) - new_temps = {} - - for temp in six.itervalues(knl.temporary_variables): - new_temps[temp.name] = temp.copy( - shape=(nbatches_expr,) + temp.shape, - dim_tags=("c",) * (len(arg.shape) + 1)) - knl = knl.copy( domains=new_domains, - args=new_args, - temporary_variables=new_temps) + args=new_args) + + if not sequential: + new_temps = {} + + for temp in six.itervalues(knl.temporary_variables): + new_temps[temp.name] = temp.copy( + shape=(nbatches_expr,) + temp.shape, + dim_tags=("c",) * (len(arg.shape) + 1)) + + knl = knl.copy(temporary_variables=new_temps) + else: + import loopy as lp + from loopy.kernel.data import ForceSequentialTag + knl = lp.tag_inames(knl, [(batch_iname, ForceSequentialTag())]) rule_mapping_context = SubstitutionRuleMappingContext( knl.substitutions, vng) bvc = _BatchVariableChanger(rule_mapping_context, - knl, batch_varying_args, batch_iname_expr) + knl, batch_varying_args, batch_iname_expr, + sequential=sequential) return rule_mapping_context.finish_kernel( bvc.map_kernel(knl)) - # }}} # vim: foldmethod=marker