From 92dc6fbec92997d0b77f31fe0a6a472c83c31e8b Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 22 Jun 2015 02:17:53 -0500
Subject: [PATCH] Allow reuse of temporaries in precompute

---
 loopy/precompute.py | 95 +++++++++++++++++++++++++++++++++++++--------
 1 file changed, 78 insertions(+), 17 deletions(-)

diff --git a/loopy/precompute.py b/loopy/precompute.py
index 935d6d440..726cc0786 100644
--- a/loopy/precompute.py
+++ b/loopy/precompute.py
@@ -132,7 +132,7 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
             access_descriptors, array_base_map,
             storage_axis_names, storage_axis_sources,
             non1_storage_axis_names,
-            target_var_name):
+            temporary_name):
         super(RuleInvocationReplacer, self).__init__(rule_mapping_context)
 
         self.subst_name = subst_name
@@ -146,7 +146,7 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
         self.storage_axis_sources = storage_axis_sources
         self.non1_storage_axis_names = non1_storage_axis_names
 
-        self.target_var_name = target_var_name
+        self.temporary_name = temporary_name
 
     def map_substitution(self, name, tag, arguments, expn_state):
         if not (
@@ -196,7 +196,7 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
             ax_index = simplify_via_aff(ax_index - sax_base_idx)
             stor_subscript.append(ax_index)
 
-        new_outer_expr = var(self.target_var_name)
+        new_outer_expr = var(self.temporary_name)
         if stor_subscript:
             new_outer_expr = new_outer_expr.index(tuple(stor_subscript))
 
@@ -210,9 +210,9 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
 
 
 def precompute(kernel, subst_use, sweep_inames=[], within=None,
-        storage_axes=None, precompute_inames=None, storage_axis_to_tag={},
-        default_tag="l.auto", dtype=None, fetch_bounding_box=False,
-        temporary_is_local=None):
+        storage_axes=None, temporary_name=None, precompute_inames=None,
+        storage_axis_to_tag={}, default_tag="l.auto", dtype=None,
+        fetch_bounding_box=False, temporary_is_local=None):
     """Precompute the expression described in the substitution rule determined by
     *subst_use* and store it in a temporary array. A precomputation needs two
     things to operate, a list of *sweep_inames* (order irrelevant) and an
@@ -263,6 +263,11 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
         May also equivalently be a comma-separated string.
     :arg within: a stack match as understood by
         :func:`loopy.context_matching.parse_stack_match`.
+    :arg temporary_name:
+        The temporary variable name to use for storing the precomputed data.
+        If it does not exist, it will be created. If it does exist, its properties
+        (such as size, type) are checked (and updated, if possible) to match
+        its use.
     :arg precompute_inames:
         If the specified inames do not already exist, they will be
         created. If they do already exist, their loop domain is verified
@@ -584,8 +589,10 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
 
     # {{{ set up compute insn
 
-    target_var_name = var_name_gen(based_on=c_subst_name)
-    assignee = var(target_var_name)
+    if temporary_name is None:
+        temporary_name = var_name_gen(based_on=c_subst_name)
+
+    assignee = var(temporary_name)
 
     if non1_storage_axis_names:
         assignee = assignee.index(
@@ -633,7 +640,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
             access_descriptors, abm,
             storage_axis_names, storage_axis_sources,
             non1_storage_axis_names,
-            target_var_name)
+            temporary_name)
 
     kernel = invr.map_kernel(kernel)
     kernel = kernel.copy(
@@ -655,15 +662,69 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     if temporary_is_local is None:
         temporary_is_local = lp.auto
 
+    new_temp_shape = tuple(abm.non1_storage_shape)
+
     new_temporary_variables = kernel.temporary_variables.copy()
-    temp_var = lp.TemporaryVariable(
-            name=target_var_name,
-            dtype=dtype,
-            base_indices=(0,)*len(abm.non1_storage_shape),
-            shape=tuple(abm.non1_storage_shape),
-            is_local=temporary_is_local)
-
-    new_temporary_variables[target_var_name] = temp_var
+    if temporary_name not in new_temporary_variables:
+        temp_var = lp.TemporaryVariable(
+                name=temporary_name,
+                dtype=dtype,
+                base_indices=(0,)*len(new_temp_shape),
+                shape=new_temp_shape,
+                is_local=temporary_is_local)
+
+    else:
+        temp_var = new_temporary_variables[temporary_name]
+
+        # {{{ check and adapt existing temporary
+
+        # An existing 'auto' dtype simply adopts the requested one.
+        if temp_var.dtype is lp.auto:
+            pass
+        elif dtype is lp.auto:
+            dtype = temp_var.dtype
+        elif temp_var.dtype != dtype:
+            raise LoopyError("Existing and new dtype of temporary '%s' "
+                    "do not match (existing: %s, new: %s)"
+                    % (temporary_name, temp_var.dtype, dtype))
+
+        temp_var = temp_var.copy(dtype=dtype)
+
+        if len(temp_var.shape) != len(new_temp_shape):
+            raise LoopyError("Existing and new temporary '%s' do not "
+                    "have matching number of dimensions (%d vs. %d)"
+                    % (temporary_name,
+                        len(temp_var.shape), len(new_temp_shape)))
+
+        if temp_var.base_indices != (0,) * len(new_temp_shape):
+            raise LoopyError("Existing temporary '%s' does not "
+                    "have all-zero base indices (found: %s, expected: %s)"
+                    % (temporary_name,
+                        temp_var.base_indices, (0,) * len(new_temp_shape)))
+
+        new_temp_shape = tuple(
+                max(i, ex_i)
+                for i, ex_i in zip(new_temp_shape, temp_var.shape))
+
+        temp_var = temp_var.copy(shape=new_temp_shape)
+
+        if temporary_is_local == temp_var.is_local:
+            pass
+        elif temporary_is_local is lp.auto:
+            temporary_is_local = temp_var.is_local
+        elif temp_var.is_local is lp.auto:
+            pass
+        else:
+            raise LoopyError("Existing and new temporary '%s' do not have "
+                    "matching values of 'is_local' (existing: %s, new: %s)"
+                    % (temporary_name,
+                        temp_var.is_local, temporary_is_local))
+
+        temp_var = temp_var.copy(is_local=temporary_is_local)
+
+        # }}}
+
+    new_temporary_variables[temporary_name] = temp_var
 
     kernel = kernel.copy(
             temporary_variables=new_temporary_variables)
-- 
GitLab