From c017b0f2f5367c23086da9dac5406ec3b3af3a4d Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 3 Mar 2016 21:20:03 -0600 Subject: [PATCH] Add 'sequential' flag to to_batched --- loopy/transform/batch.py | 42 ++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/loopy/transform/batch.py b/loopy/transform/batch.py index 1dc54f94b..967e14de6 100644 --- a/loopy/transform/batch.py +++ b/loopy/transform/batch.py @@ -38,16 +38,18 @@ __doc__ = """ class _BatchVariableChanger(RuleAwareIdentityMapper): def __init__(self, rule_mapping_context, kernel, batch_varying_args, - batch_iname_expr): + batch_iname_expr, sequential): super(_BatchVariableChanger, self).__init__(rule_mapping_context) self.kernel = kernel self.batch_varying_args = batch_varying_args self.batch_iname_expr = batch_iname_expr + self.sequential = sequential def needs_batch_subscript(self, name): return ( - name in self.kernel.temporary_variables + (not self.sequential + and name in self.kernel.temporary_variables) or name in self.batch_varying_args) @@ -68,14 +70,18 @@ class _BatchVariableChanger(RuleAwareIdentityMapper): return expr.aggregate[self.batch_iname_expr] -def to_batched(knl, nbatches, batch_varying_args, batch_iname_prefix="ibatch"): +def to_batched(knl, nbatches, batch_varying_args, batch_iname_prefix="ibatch", + sequential=False): """Takes in a kernel that carries out an operation and returns a kernel that carries out a batch of these operations. :arg nbatches: the number of batches. May be a constant non-negative integer or a string, which will be added as an integer argument. - :arg batch_varying_args: a list of argument names that depend vary per-batch. + :arg batch_varying_args: a list of argument names that vary per-batch. Each such variable will have a batch index added. + :arg sequential: A :class:`bool`. If *True*, do not duplicate + temporary variables for each batch. This automatically tags the batch + iname for sequential execution. """ from pymbolic import var @@ -114,26 +120,32 @@ def to_batched(knl, nbatches, batch_varying_args, batch_iname_prefix="ibatch"): new_args.append(arg) - new_temps = {} - - for temp in six.itervalues(knl.temporary_variables): - new_temps[temp.name] = temp.copy( - shape=(nbatches_expr,) + temp.shape, - dim_tags=("c",) * (len(arg.shape) + 1)) - knl = knl.copy( domains=new_domains, - args=new_args, - temporary_variables=new_temps) + args=new_args) + + if not sequential: + new_temps = {} + + for temp in six.itervalues(knl.temporary_variables): + new_temps[temp.name] = temp.copy( + shape=(nbatches_expr,) + temp.shape, + dim_tags=("c",) * (len(arg.shape) + 1)) + + knl = knl.copy(temporary_variables=new_temps) + else: + import loopy as lp + from loopy.kernel.data import ForceSequentialTag + knl = lp.tag_inames(knl, [(batch_iname, ForceSequentialTag())]) rule_mapping_context = SubstitutionRuleMappingContext( knl.substitutions, vng) bvc = _BatchVariableChanger(rule_mapping_context, - knl, batch_varying_args, batch_iname_expr) + knl, batch_varying_args, batch_iname_expr, + sequential=sequential) return rule_mapping_context.finish_kernel( bvc.map_kernel(knl)) - # }}} # vim: foldmethod=marker -- GitLab