From 0c541e46acd025d53a9506b42e867aee056ecc62 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Tue, 9 Jan 2018 19:22:26 -0600 Subject: [PATCH] Changed documentations and attempt to shorten the code. --- loopy/transform/batch.py | 47 ++++++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/loopy/transform/batch.py b/loopy/transform/batch.py index 6dbb03b7b..d02c0fc35 100644 --- a/loopy/transform/batch.py +++ b/loopy/transform/batch.py @@ -38,6 +38,20 @@ __doc__ = """ # {{{ to_batched +def temp_needs_batching_if_not_sequential(tv, batch_varying_args): + from loopy.kernel.data import temp_var_scope + if tv.name in batch_varying_args: + return True + if tv.initializer is not None and tv.read_only: + # do not batch read_only temps if not in + # `batch_varying_args` + return False + if tv.scope == temp_var_scope.PRIVATE: + # do not batch private temps if not in `batch_varying args` + return False + return True + + class _BatchVariableChanger(RuleAwareIdentityMapper): def __init__(self, rule_mapping_context, kernel, batch_varying_args, batch_iname_expr, sequential): @@ -50,16 +64,17 @@ class _BatchVariableChanger(RuleAwareIdentityMapper): def needs_batch_subscript(self, name): tv = self.kernel.temporary_variables.get(name) - from loopy.kernel.data import temp_var_scope - return ( - (not self.sequential - and (tv is not None - and not (( - tv.initializer is not None - and tv.read_only) or ( - tv.scope == temp_var_scope.PRIVATE)))) - or - name in self.batch_varying_args) + + if name in self.batch_varying_args: + return True + if not self.sequential: + if tv is None: + return False + if not temp_needs_batching_if_not_sequential(tv, + self.batch_varying_args): + return False + + return True def map_subscript(self, expr, expn_state): if not self.needs_batch_subscript(expr.aggregate.name): @@ -91,6 +106,9 @@ def to_batched(knl, nbatches, batch_varying_args, batch_iname_prefix="ibatch", sequential=False): """Takes in a kernel that carries out an operation and returns a kernel that carries out a batch of these operations. + ***Note:* For temporaries in a kernel that are private or read only + globals, loopy does not does not batch these variables if not mentioned + explicitly in `batch_varying_args`. :arg nbatches: the number of batches. May be a constant non-negative integer or a string, which will be added as an integer argument. @@ -144,18 +162,15 @@ def to_batched(knl, nbatches, batch_varying_args, batch_iname_prefix="ibatch", if not sequential: new_temps = {} - from loopy.kernel.data import temp_var_scope for temp in six.itervalues(knl.temporary_variables): - if (temp.initializer is not None and temp.read_only) or ( - temp.scope == temp_var_scope.PRIVATE and temp.name not in - batch_varying_args): - new_temps[temp.name] = temp - else: + if temp_needs_batching_if_not_sequential(temp, batch_varying_args): new_temps[temp.name] = temp.copy( shape=(nbatches_expr,) + temp.shape, dim_tags=("c",) * (len(temp.shape) + 1), dim_names=_add_unique_dim_name("ibatch", temp.dim_names)) + else: + new_temps[temp.name] = temp knl = knl.copy(temporary_variables=new_temps) else: -- GitLab