From 0c541e46acd025d53a9506b42e867aee056ecc62 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Tue, 9 Jan 2018 19:22:26 -0600
Subject: [PATCH] Change documentation and attempt to shorten the code.

---
 loopy/transform/batch.py | 47 ++++++++++++++++++++++++++--------------
 1 file changed, 31 insertions(+), 16 deletions(-)

diff --git a/loopy/transform/batch.py b/loopy/transform/batch.py
index 6dbb03b7b..d02c0fc35 100644
--- a/loopy/transform/batch.py
+++ b/loopy/transform/batch.py
@@ -38,6 +38,20 @@ __doc__ = """
 
 # {{{ to_batched
 
+def temp_needs_batching_if_not_sequential(tv, batch_varying_args):
+    """Return *True* if the temporary *tv* gets a batch axis when batching
+    non-sequentially: always if named in *batch_varying_args*, otherwise
+    only if it is neither read-only with an initializer nor private.
+    """
+    from loopy.kernel.data import temp_var_scope
+    if tv.name in batch_varying_args:
+        return True
+    # read-only initialized temps and private temps stay unbatched unless
+    # they are explicitly listed in `batch_varying_args`
+    return not ((tv.initializer is not None and tv.read_only)
+            or tv.scope == temp_var_scope.PRIVATE)
+
+
 class _BatchVariableChanger(RuleAwareIdentityMapper):
     def __init__(self, rule_mapping_context, kernel, batch_varying_args,
             batch_iname_expr, sequential):
@@ -50,16 +64,17 @@ class _BatchVariableChanger(RuleAwareIdentityMapper):
 
     def needs_batch_subscript(self, name):
         tv = self.kernel.temporary_variables.get(name)
-        from loopy.kernel.data import temp_var_scope
-        return (
-                (not self.sequential
-                    and (tv is not None
-                        and not ((
-                            tv.initializer is not None
-                            and tv.read_only) or (
-                                tv.scope == temp_var_scope.PRIVATE))))
-                or
-                name in self.batch_varying_args)
+
+        if name in self.batch_varying_args:
+            return True
+        if self.sequential or tv is None:
+            # Sequential batching only subscripts batch-varying names, and
+            # a name that is not a temporary is never batched implicitly.
+            return False
+
+        # Non-sequential: the kind of temporary decides.
+        return temp_needs_batching_if_not_sequential(tv,
+                self.batch_varying_args)
 
     def map_subscript(self, expr, expn_state):
         if not self.needs_batch_subscript(expr.aggregate.name):
@@ -91,6 +106,9 @@ def to_batched(knl, nbatches, batch_varying_args, batch_iname_prefix="ibatch",
         sequential=False):
     """Takes in a kernel that carries out an operation and returns a kernel
     that carries out a batch of these operations.
+    *Note:* Temporaries that are private, or read-only with an initializer,
+    are not batched unless they are explicitly listed in
+    *batch_varying_args*.
 
     :arg nbatches: the number of batches. May be a constant non-negative
         integer or a string, which will be added as an integer argument.
@@ -144,18 +162,15 @@ def to_batched(knl, nbatches, batch_varying_args, batch_iname_prefix="ibatch",
 
     if not sequential:
         new_temps = {}
-        from loopy.kernel.data import temp_var_scope
 
         for temp in six.itervalues(knl.temporary_variables):
-            if (temp.initializer is not None and temp.read_only) or (
-                    temp.scope == temp_var_scope.PRIVATE and temp.name not in
-                    batch_varying_args):
-                new_temps[temp.name] = temp
-            else:
+            if temp_needs_batching_if_not_sequential(temp, batch_varying_args):
                 new_temps[temp.name] = temp.copy(
                         shape=(nbatches_expr,) + temp.shape,
                         dim_tags=("c",) * (len(temp.shape) + 1),
                         dim_names=_add_unique_dim_name("ibatch", temp.dim_names))
+            else:
+                new_temps[temp.name] = temp
 
         knl = knl.copy(temporary_variables=new_temps)
     else:
-- 
GitLab