Skip to content
Snippets Groups Projects
Commit 8e1fcec3 authored by Andreas Klöckner
Browse files

Make realize_reduction usable even if types aren't known

parent f2e9cb3d
No related branches found
No related tags found
No related merge requests found
......@@ -466,7 +466,7 @@ def add_default_dependencies(kernel):
# {{{ rewrite reduction to imperative form
def realize_reduction(kernel, insn_id_filter=None):
def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
"""Rewrites reductions into their imperative form. With *insn_id_filter*
specified, operate only on the instruction with an instruction id matching
*insn_id_filter*.
......@@ -494,7 +494,7 @@ def realize_reduction(kernel, insn_id_filter=None):
# {{{ sequential
def map_reduction_seq(expr, rec, multiple_values_ok, arg_dtype,
def map_reduction_seq(expr, rec, nresults, arg_dtype,
reduction_dtypes):
outer_insn_inames = temp_kernel.insn_inames(insn)
ncomp = len(reduction_dtypes)
......@@ -549,11 +549,11 @@ def realize_reduction(kernel, insn_id_filter=None):
new_insn_add_depends_on.add(reduction_insn.id)
if multiple_values_ok:
return acc_vars
else:
if nresults == 1:
assert len(acc_vars) == 1
return acc_vars[0]
else:
return acc_vars
# }}}
......@@ -577,7 +577,7 @@ def realize_reduction(kernel, insn_id_filter=None):
v[iname].lt_set(v[0] + size)).get_basic_sets()
return bs
def map_reduction_local(expr, rec, multiple_values_ok, arg_dtype,
def map_reduction_local(expr, rec, nresults, arg_dtype,
reduction_dtypes):
red_iname, = expr.inames
ncomp = len(reduction_dtypes)
......@@ -654,7 +654,7 @@ def realize_reduction(kernel, insn_id_filter=None):
(outer_insn_inames - frozenset(expr.inames))
| frozenset([red_iname])),
forced_iname_deps_is_final=insn.forced_iname_deps_is_final,
depends_on=frozenset([init_id]),
depends_on=frozenset([init_id]) | insn.depends_on,
no_sync_with=frozenset([init_id]))
generated_insns.append(transfer_insn)
......@@ -715,33 +715,40 @@ def realize_reduction(kernel, insn_id_filter=None):
new_insn_add_depends_on.add(prev_id)
new_insn_add_no_sync_with.add(prev_id)
new_insn_add_forced_iname_deps.add(stage_exec_iname or base_exec_iname)
new_insn_add_stop_iname_dep_propagation.add(
stage_exec_iname or base_exec_iname)
if multiple_values_ok:
return [acc_var[outer_local_iname_vars + (0,)] for acc_var in acc_vars]
else:
if nresults == 1:
assert len(acc_vars) == 1
return acc_vars[0][outer_local_iname_vars + (0,)]
else:
return [acc_var[outer_local_iname_vars + (0,)] for acc_var in acc_vars]
# }}}
# {{{ seq/par dispatch
def map_reduction(expr, rec, multiple_values_ok=False):
def map_reduction(expr, rec, nresults=1):
# Only expand one level of reduction at a time, going from outermost to
# innermost. Otherwise we get the (iname + insn) dependencies wrong.
try:
arg_dtype = type_inf_mapper(expr.expr)
except DependencyTypeInferenceFailure:
raise LoopyError("failed to determine type of accumulator for "
"reduction '%s'" % expr)
if unknown_types_ok:
arg_dtype = lp.auto
reduction_dtypes = (lp.auto,)*nresults
arg_dtype = arg_dtype.with_target(kernel.target)
else:
raise LoopyError("failed to determine type of accumulator for "
"reduction '%s'" % expr)
else:
arg_dtype = arg_dtype.with_target(kernel.target)
reduction_dtypes = expr.operation.result_dtypes(
kernel, arg_dtype, expr.inames)
reduction_dtypes = tuple(
dt.with_target(kernel.target) for dt in reduction_dtypes)
reduction_dtypes = expr.operation.result_dtypes(
kernel, arg_dtype, expr.inames)
reduction_dtypes = tuple(
dt.with_target(kernel.target) for dt in reduction_dtypes)
outer_insn_inames = temp_kernel.insn_inames(insn)
bad_inames = frozenset(expr.inames) & outer_insn_inames
......@@ -791,10 +798,10 @@ def realize_reduction(kernel, insn_id_filter=None):
if n_sequential:
assert n_local_par == 0
return map_reduction_seq(expr, rec, multiple_values_ok, arg_dtype,
return map_reduction_seq(expr, rec, nresults, arg_dtype,
reduction_dtypes)
elif n_local_par:
return map_reduction_local(expr, rec, multiple_values_ok, arg_dtype,
return map_reduction_local(expr, rec, nresults, arg_dtype,
reduction_dtypes)
else:
from loopy.diagnostic import warn
......@@ -820,6 +827,7 @@ def realize_reduction(kernel, insn_id_filter=None):
new_insn_add_depends_on = set()
new_insn_add_no_sync_with = set()
new_insn_add_forced_iname_deps = set()
new_insn_add_stop_iname_dep_propagation = set()
generated_insns = []
......@@ -830,10 +838,12 @@ def realize_reduction(kernel, insn_id_filter=None):
new_insns.append(insn)
continue
nresults = len(insn.assignees)
# Run reduction expansion.
from loopy.symbolic import Reduction
if isinstance(insn.expression, Reduction):
new_expressions = cb_mapper(insn.expression, multiple_values_ok=True)
if isinstance(insn.expression, Reduction) and nresults > 1:
new_expressions = cb_mapper(insn.expression, nresults=nresults)
else:
new_expressions = (cb_mapper(insn.expression),)
......@@ -848,7 +858,10 @@ def realize_reduction(kernel, insn_id_filter=None):
| frozenset(new_insn_add_no_sync_with),
forced_iname_deps=(
temp_kernel.insn_inames(insn)
| new_insn_add_forced_iname_deps)
| new_insn_add_forced_iname_deps),
stop_iname_dep_propagation=(
insn.stop_iname_dep_propagation
| new_insn_add_stop_iname_dep_propagation),
)
kwargs.pop("id")
......@@ -876,7 +889,8 @@ def realize_reduction(kernel, insn_id_filter=None):
temp_kernel = kernel.copy(
instructions=new_insns + insn_queue,
temporary_variables=new_temporary_variables)
temporary_variables=new_temporary_variables,
domains=domains)
else:
# nothing happened, we're done with insn
......@@ -1080,7 +1094,7 @@ def preprocess_kernel(kernel, device=None):
# because it manipulates the depends_on field, which could prevent
# defaults from being applied.
kernel = realize_reduction(kernel)
kernel = realize_reduction(kernel, unknown_types_ok=False)
# Ordering restriction:
# add_axes_to_temporaries_for_ilp because reduction accumulators
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment