From 8e1fcec30f3c8294f343a3ca4aad68da38bdae34 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 26 May 2016 02:17:12 +0200 Subject: [PATCH] Make realize_reduction usable even if types aren't known --- loopy/preprocess.py | 66 +++++++++++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 7b94d1ea9..12dc2973e 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -466,7 +466,7 @@ def add_default_dependencies(kernel): # {{{ rewrite reduction to imperative form -def realize_reduction(kernel, insn_id_filter=None): +def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): """Rewrites reductions into their imperative form. With *insn_id_filter* specified, operate only on the instruction with an instruction id matching *insn_id_filter*. @@ -494,7 +494,7 @@ def realize_reduction(kernel, insn_id_filter=None): # {{{ sequential - def map_reduction_seq(expr, rec, multiple_values_ok, arg_dtype, + def map_reduction_seq(expr, rec, nresults, arg_dtype, reduction_dtypes): outer_insn_inames = temp_kernel.insn_inames(insn) ncomp = len(reduction_dtypes) @@ -549,11 +549,11 @@ def realize_reduction(kernel, insn_id_filter=None): new_insn_add_depends_on.add(reduction_insn.id) - if multiple_values_ok: - return acc_vars - else: + if nresults == 1: assert len(acc_vars) == 1 return acc_vars[0] + else: + return acc_vars # }}} @@ -577,7 +577,7 @@ def realize_reduction(kernel, insn_id_filter=None): v[iname].lt_set(v[0] + size)).get_basic_sets() return bs - def map_reduction_local(expr, rec, multiple_values_ok, arg_dtype, + def map_reduction_local(expr, rec, nresults, arg_dtype, reduction_dtypes): red_iname, = expr.inames ncomp = len(reduction_dtypes) @@ -654,7 +654,7 @@ def realize_reduction(kernel, insn_id_filter=None): (outer_insn_inames - frozenset(expr.inames)) | frozenset([red_iname])), forced_iname_deps_is_final=insn.forced_iname_deps_is_final, - 
depends_on=frozenset([init_id]), + depends_on=frozenset([init_id]) | insn.depends_on, no_sync_with=frozenset([init_id])) generated_insns.append(transfer_insn) @@ -715,33 +715,40 @@ def realize_reduction(kernel, insn_id_filter=None): new_insn_add_depends_on.add(prev_id) new_insn_add_no_sync_with.add(prev_id) new_insn_add_forced_iname_deps.add(stage_exec_iname or base_exec_iname) + new_insn_add_stop_iname_dep_propagation.add( + stage_exec_iname or base_exec_iname) - if multiple_values_ok: - return [acc_var[outer_local_iname_vars + (0,)] for acc_var in acc_vars] - else: + if nresults == 1: assert len(acc_vars) == 1 return acc_vars[0][outer_local_iname_vars + (0,)] - + else: + return [acc_var[outer_local_iname_vars + (0,)] for acc_var in acc_vars] # }}} # {{{ seq/par dispatch - def map_reduction(expr, rec, multiple_values_ok=False): + def map_reduction(expr, rec, nresults=1): # Only expand one level of reduction at a time, going from outermost to # innermost. Otherwise we get the (iname + insn) dependencies wrong. 
try: arg_dtype = type_inf_mapper(expr.expr) except DependencyTypeInferenceFailure: - raise LoopyError("failed to determine type of accumulator for " - "reduction '%s'" % expr) + if unknown_types_ok: + arg_dtype = lp.auto + + reduction_dtypes = (lp.auto,)*nresults - arg_dtype = arg_dtype.with_target(kernel.target) + else: + raise LoopyError("failed to determine type of accumulator for " + "reduction '%s'" % expr) + else: + arg_dtype = arg_dtype.with_target(kernel.target) - reduction_dtypes = expr.operation.result_dtypes( - kernel, arg_dtype, expr.inames) - reduction_dtypes = tuple( - dt.with_target(kernel.target) for dt in reduction_dtypes) + reduction_dtypes = expr.operation.result_dtypes( + kernel, arg_dtype, expr.inames) + reduction_dtypes = tuple( + dt.with_target(kernel.target) for dt in reduction_dtypes) outer_insn_inames = temp_kernel.insn_inames(insn) bad_inames = frozenset(expr.inames) & outer_insn_inames @@ -791,10 +798,10 @@ def realize_reduction(kernel, insn_id_filter=None): if n_sequential: assert n_local_par == 0 - return map_reduction_seq(expr, rec, multiple_values_ok, arg_dtype, + return map_reduction_seq(expr, rec, nresults, arg_dtype, reduction_dtypes) elif n_local_par: - return map_reduction_local(expr, rec, multiple_values_ok, arg_dtype, + return map_reduction_local(expr, rec, nresults, arg_dtype, reduction_dtypes) else: from loopy.diagnostic import warn @@ -820,6 +827,7 @@ def realize_reduction(kernel, insn_id_filter=None): new_insn_add_depends_on = set() new_insn_add_no_sync_with = set() new_insn_add_forced_iname_deps = set() + new_insn_add_stop_iname_dep_propagation = set() generated_insns = [] @@ -830,10 +838,12 @@ def realize_reduction(kernel, insn_id_filter=None): new_insns.append(insn) continue + nresults = len(insn.assignees) + # Run reduction expansion. 
from loopy.symbolic import Reduction - if isinstance(insn.expression, Reduction): - new_expressions = cb_mapper(insn.expression, multiple_values_ok=True) + if isinstance(insn.expression, Reduction) and nresults > 1: + new_expressions = cb_mapper(insn.expression, nresults=nresults) else: new_expressions = (cb_mapper(insn.expression),) @@ -848,7 +858,10 @@ def realize_reduction(kernel, insn_id_filter=None): | frozenset(new_insn_add_no_sync_with), forced_iname_deps=( temp_kernel.insn_inames(insn) - | new_insn_add_forced_iname_deps) + | new_insn_add_forced_iname_deps), + stop_iname_dep_propagation=( + insn.stop_iname_dep_propagation + | new_insn_add_stop_iname_dep_propagation), ) kwargs.pop("id") @@ -876,7 +889,8 @@ def realize_reduction(kernel, insn_id_filter=None): temp_kernel = kernel.copy( instructions=new_insns + insn_queue, - temporary_variables=new_temporary_variables) + temporary_variables=new_temporary_variables, + domains=domains) else: # nothing happened, we're done with insn @@ -1080,7 +1094,7 @@ def preprocess_kernel(kernel, device=None): # because it manipulates the depends_on field, which could prevent # defaults from being applied. - kernel = realize_reduction(kernel) + kernel = realize_reduction(kernel, unknown_types_ok=False) # Ordering restriction: # add_axes_to_temporaries_for_ilp because reduction accumulators -- GitLab