From f5f113cf06c39ed7d66eeba0fa46950ce000b52f Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 11 Aug 2015 23:53:17 -0500
Subject: [PATCH] Add dependencies on compute insn in precompute()

---
 loopy/precompute.py | 33 ++++++++++++++++++++++++++-------
 1 file changed, 26 insertions(+), 7 deletions(-)

diff --git a/loopy/precompute.py b/loopy/precompute.py
index b1df5f678..ee7f815cf 100644
--- a/loopy/precompute.py
+++ b/loopy/precompute.py
@@ -136,7 +136,7 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
             access_descriptors, array_base_map,
             storage_axis_names, storage_axis_sources,
             non1_storage_axis_names,
-            temporary_name):
+            temporary_name, compute_insn_id):
         super(RuleInvocationReplacer, self).__init__(rule_mapping_context)
 
         self.subst_name = subst_name
@@ -151,6 +151,7 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
         self.non1_storage_axis_names = non1_storage_axis_names
 
         self.temporary_name = temporary_name
+        self.compute_insn_id = compute_insn_id
 
     def map_substitution(self, name, tag, arguments, expn_state):
         if not (
@@ -211,8 +212,26 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
         # further as compute expression has already been seen
         # by rule_mapping_context.
 
+        self.replaced_something = True
+
         return new_outer_expr
 
+    def map_kernel(self, kernel):
+        new_insns = []
+        # Rewrite every instruction; add a compute-insn dependency where needed.
+        for insn in kernel.instructions:
+            self.replaced_something = False
+            # Reset above: the mapper sets the flag when it replaces an invocation.
+            insn = insn.with_transformed_expressions(self, kernel, insn)
+            # Flag set => this insn had an invocation replaced; add the dependency.
+            if self.replaced_something:
+                insn = insn.copy(
+                        insn_deps=insn.insn_deps | frozenset([self.compute_insn_id]))
+            # Appended whether or not it was modified, so none are dropped.
+            new_insns.append(insn)
+        # Copy of the kernel carrying the rewritten instruction list.
+        return kernel.copy(instructions=new_insns)
+
 # }}}
 
 
@@ -220,7 +239,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
         storage_axes=None, temporary_name=None, precompute_inames=None,
         storage_axis_to_tag={}, default_tag="l.auto", dtype=None,
         fetch_bounding_box=False, temporary_is_local=None,
-        insn_id=None):
+        compute_insn_id=None):
     """Precompute the expression described in the substitution rule determined by
     *subst_use* and store it in a temporary array. A precomputation needs two
     things to operate, a list of *sweep_inames* (order irrelevant) and an
@@ -280,7 +299,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
         If the specified inames do not already exist, they will be
         created. If they do already exist, their loop domain is verified
         against the one required for this precomputation.
-    :arg insn_id: The ID of the instruction performing the precomputation.
+    :arg compute_insn_id: The ID of the instruction performing the precomputation.
 
     If `storage_axes` is not specified, it defaults to the arrangement
     `<direct sweep axes><arguments>` with the direct sweep axes being the
@@ -686,11 +705,11 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     # }}}
 
     from loopy.kernel.data import ExpressionInstruction
-    if insn_id is None:
-        insn_id = kernel.make_unique_instruction_id(based_on=c_subst_name)
+    if compute_insn_id is None:
+        compute_insn_id = kernel.make_unique_instruction_id(based_on=c_subst_name)
 
     compute_insn = ExpressionInstruction(
-            id=insn_id,
+            id=compute_insn_id,
             assignee=assignee,
             expression=compute_expression)
 
@@ -703,7 +722,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
             access_descriptors, abm,
             storage_axis_names, storage_axis_sources,
             non1_storage_axis_names,
-            temporary_name)
+            temporary_name, compute_insn_id)
 
     kernel = invr.map_kernel(kernel)
     kernel = kernel.copy(
-- 
GitLab