From 8e1fcec30f3c8294f343a3ca4aad68da38bdae34 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 26 May 2016 02:17:12 +0200
Subject: [PATCH] Make realize_reduction usable even if types aren't known

---
 loopy/preprocess.py | 66 +++++++++++++++++++++++++++------------------
 1 file changed, 40 insertions(+), 26 deletions(-)

diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 7b94d1ea9..12dc2973e 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -466,7 +466,7 @@ def add_default_dependencies(kernel):
 
 # {{{ rewrite reduction to imperative form
 
-def realize_reduction(kernel, insn_id_filter=None):
+def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
     """Rewrites reductions into their imperative form. With *insn_id_filter*
     specified, operate only on the instruction with an instruction id matching
     *insn_id_filter*.
@@ -494,7 +494,7 @@ def realize_reduction(kernel, insn_id_filter=None):
 
     # {{{ sequential
 
-    def map_reduction_seq(expr, rec, multiple_values_ok, arg_dtype,
+    def map_reduction_seq(expr, rec, nresults, arg_dtype,
             reduction_dtypes):
         outer_insn_inames = temp_kernel.insn_inames(insn)
         ncomp = len(reduction_dtypes)
@@ -549,11 +549,11 @@ def realize_reduction(kernel, insn_id_filter=None):
 
         new_insn_add_depends_on.add(reduction_insn.id)
 
-        if multiple_values_ok:
-            return acc_vars
-        else:
+        if nresults == 1:
             assert len(acc_vars) == 1
             return acc_vars[0]
+        else:
+            return acc_vars
 
     # }}}
 
@@ -577,7 +577,7 @@ def realize_reduction(kernel, insn_id_filter=None):
                 v[iname].lt_set(v[0] + size)).get_basic_sets()
         return bs
 
-    def map_reduction_local(expr, rec, multiple_values_ok, arg_dtype,
+    def map_reduction_local(expr, rec, nresults, arg_dtype,
             reduction_dtypes):
         red_iname, = expr.inames
         ncomp = len(reduction_dtypes)
@@ -654,7 +654,7 @@ def realize_reduction(kernel, insn_id_filter=None):
                     (outer_insn_inames - frozenset(expr.inames))
                     | frozenset([red_iname])),
                 forced_iname_deps_is_final=insn.forced_iname_deps_is_final,
-                depends_on=frozenset([init_id]),
+                depends_on=frozenset([init_id]) | insn.depends_on,
                 no_sync_with=frozenset([init_id]))
         generated_insns.append(transfer_insn)
 
@@ -715,33 +715,40 @@ def realize_reduction(kernel, insn_id_filter=None):
         new_insn_add_depends_on.add(prev_id)
         new_insn_add_no_sync_with.add(prev_id)
         new_insn_add_forced_iname_deps.add(stage_exec_iname or base_exec_iname)
+        new_insn_add_stop_iname_dep_propagation.add(
+                stage_exec_iname or base_exec_iname)
 
-        if multiple_values_ok:
-            return [acc_var[outer_local_iname_vars + (0,)] for acc_var in acc_vars]
-        else:
+        if nresults == 1:
             assert len(acc_vars) == 1
             return acc_vars[0][outer_local_iname_vars + (0,)]
-
+        else:
+            return [acc_var[outer_local_iname_vars + (0,)] for acc_var in acc_vars]
     # }}}
 
     # {{{ seq/par dispatch
 
-    def map_reduction(expr, rec, multiple_values_ok=False):
+    def map_reduction(expr, rec, nresults=1):
         # Only expand one level of reduction at a time, going from outermost to
         # innermost. Otherwise we get the (iname + insn) dependencies wrong.
 
         try:
             arg_dtype = type_inf_mapper(expr.expr)
         except DependencyTypeInferenceFailure:
-            raise LoopyError("failed to determine type of accumulator for "
-                    "reduction '%s'" % expr)
+            if unknown_types_ok:
+                arg_dtype = lp.auto
+
+                reduction_dtypes = (lp.auto,)*nresults
 
-        arg_dtype = arg_dtype.with_target(kernel.target)
+            else:
+                raise LoopyError("failed to determine type of accumulator for "
+                        "reduction '%s'" % expr)
+        else:
+            arg_dtype = arg_dtype.with_target(kernel.target)
 
-        reduction_dtypes = expr.operation.result_dtypes(
-                    kernel, arg_dtype, expr.inames)
-        reduction_dtypes = tuple(
-                dt.with_target(kernel.target) for dt in reduction_dtypes)
+            reduction_dtypes = expr.operation.result_dtypes(
+                        kernel, arg_dtype, expr.inames)
+            reduction_dtypes = tuple(
+                    dt.with_target(kernel.target) for dt in reduction_dtypes)
 
         outer_insn_inames = temp_kernel.insn_inames(insn)
         bad_inames = frozenset(expr.inames) & outer_insn_inames
@@ -791,10 +798,10 @@ def realize_reduction(kernel, insn_id_filter=None):
 
         if n_sequential:
             assert n_local_par == 0
-            return map_reduction_seq(expr, rec, multiple_values_ok, arg_dtype,
+            return map_reduction_seq(expr, rec, nresults, arg_dtype,
                     reduction_dtypes)
         elif n_local_par:
-            return map_reduction_local(expr, rec, multiple_values_ok, arg_dtype,
+            return map_reduction_local(expr, rec, nresults, arg_dtype,
                     reduction_dtypes)
         else:
             from loopy.diagnostic import warn
@@ -820,6 +827,7 @@ def realize_reduction(kernel, insn_id_filter=None):
         new_insn_add_depends_on = set()
         new_insn_add_no_sync_with = set()
         new_insn_add_forced_iname_deps = set()
+        new_insn_add_stop_iname_dep_propagation = set()
 
         generated_insns = []
 
@@ -830,10 +838,12 @@ def realize_reduction(kernel, insn_id_filter=None):
             new_insns.append(insn)
             continue
 
+        nresults = len(insn.assignees)
+
         # Run reduction expansion.
         from loopy.symbolic import Reduction
-        if isinstance(insn.expression, Reduction):
-            new_expressions = cb_mapper(insn.expression, multiple_values_ok=True)
+        if isinstance(insn.expression, Reduction) and nresults > 1:
+            new_expressions = cb_mapper(insn.expression, nresults=nresults)
         else:
             new_expressions = (cb_mapper(insn.expression),)
 
@@ -848,7 +858,10 @@ def realize_reduction(kernel, insn_id_filter=None):
                     | frozenset(new_insn_add_no_sync_with),
                     forced_iname_deps=(
                         temp_kernel.insn_inames(insn)
-                        | new_insn_add_forced_iname_deps)
+                        | new_insn_add_forced_iname_deps),
+                    stop_iname_dep_propagation=(
+                        insn.stop_iname_dep_propagation
+                        | new_insn_add_stop_iname_dep_propagation),
                     )
 
             kwargs.pop("id")
@@ -876,7 +889,8 @@ def realize_reduction(kernel, insn_id_filter=None):
 
             temp_kernel = kernel.copy(
                     instructions=new_insns + insn_queue,
-                    temporary_variables=new_temporary_variables)
+                    temporary_variables=new_temporary_variables,
+                    domains=domains)
 
         else:
             # nothing happened, we're done with insn
@@ -1080,7 +1094,7 @@ def preprocess_kernel(kernel, device=None):
     #   because it manipulates the depends_on field, which could prevent
     #   defaults from being applied.
 
-    kernel = realize_reduction(kernel)
+    kernel = realize_reduction(kernel, unknown_types_ok=False)
 
     # Ordering restriction:
     # add_axes_to_temporaries_for_ilp because reduction accumulators
-- 
GitLab