Skip to content
Snippets Groups Projects
Commit 69bd5756 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Fixes for the subst rule/precompute rewrite.

parent abbded75
No related branches found
No related tags found
No related merge requests found
......@@ -428,7 +428,7 @@ def add_prefetch(kernel, var_name, sweep_dims, dim_args=None,
newly_created_vars = set()
parameters = []
for i in range(len(arg.shape)):
based_on = "%s_i%d" % (var_name, i)
based_on = "%s_fetch_%d" % (var_name, i)
if dim_args is not None and i < len(dim_args):
based_on = dim_args[i]
......
......@@ -58,7 +58,7 @@ def check_for_double_use_of_hw_axes(kernel):
if isinstance(tag, UniqueTag):
key = tag.key
if key in insn_tag_keys:
raise RuntimeError("instruction '%s' has two "
raise RuntimeError("instruction '%s' has multiple "
"inames tagged '%s'" % (insn.id, tag))
insn_tag_keys.add(key)
......@@ -155,7 +155,7 @@ def check_for_write_races(kernel):
raise RuntimeError(
"instruction '%s' contains a write race: "
"instruction will be run across parallel iname(s) '%s', which "
"is/are not referenced in the assignee index"
"is/are not referenced in the lhs index"
% (insn.id, ",".join(inames_without_write_dep)))
def check_for_orphaned_user_hardware_axes(kernel):
......
......@@ -36,7 +36,7 @@ def to_parameters_or_project_out(param_inames, set_inames, set):
def get_footprint(kernel, subst_name, arg_names, unique_new_arg_names,
def get_footprint(kernel, subst_name, old_arg_names, arg_names,
sweep_inames, invocation_descriptors):
global_footprint_map = None
......@@ -45,8 +45,8 @@ def get_footprint(kernel, subst_name, arg_names, unique_new_arg_names,
for invdesc in invocation_descriptors:
for iname in sweep_inames:
if iname in arg_names:
arg_idx = arg_names.index(iname)
if iname in old_arg_names:
arg_idx = old_arg_names.index(iname)
processed_sweep_inames.add(
get_dependencies(invdesc.args[arg_idx]))
else:
......@@ -55,11 +55,11 @@ def get_footprint(kernel, subst_name, arg_names, unique_new_arg_names,
# {{{ construct, check mapping
map_space = kernel.space
ln = len(unique_new_arg_names)
ln = len(arg_names)
rn = kernel.space.dim(dim_type.out)
map_space = map_space.add_dims(dim_type.in_, ln)
for i, iname in enumerate(unique_new_arg_names):
for i, iname in enumerate(arg_names):
map_space = map_space.set_dim_name(dim_type.in_, i, iname+"'")
set_space = map_space.move_dims(
......@@ -69,7 +69,7 @@ def get_footprint(kernel, subst_name, arg_names, unique_new_arg_names,
footprint_map = None
from loopy.symbolic import aff_from_expr
for uarg_name, arg_val in zip(unique_new_arg_names, invdesc.args):
for uarg_name, arg_val in zip(arg_names, invdesc.args):
cns = isl.Constraint.equality_from_aff(
aff_from_expr(set_space, var(uarg_name+"'") - arg_val))
......@@ -92,7 +92,8 @@ def get_footprint(kernel, subst_name, arg_names, unique_new_arg_names,
processed_sweep_inames = list(processed_sweep_inames)
global_footprint_map = global_footprint_map.intersect_range(kernel.domain)
global_footprint_map = (isl.Map.from_basic_map(global_footprint_map)
.intersect_range(kernel.domain))
# move non-sweep-dimensions into parameter space
sweep_footprint_map = global_footprint_map.coalesce()
......@@ -114,51 +115,72 @@ def get_footprint(kernel, subst_name, arg_names, unique_new_arg_names,
% subst_name)
from loopy.kernel import find_var_base_indices_and_shape_from_inames
base_indices, shape = find_var_base_indices_and_shape_from_inames(
sfm_dom, [uarg+"'" for uarg in unique_new_arg_names],
arg_base_indices, shape = find_var_base_indices_and_shape_from_inames(
sfm_dom, [uarg+"'" for uarg in arg_names],
kernel.cache_manager)
print arg_names, shape
# compute augmented domain
# {{{ filter out unit-length dimensions
non1_arg_names = []
non1_arg_base_indices = []
non1_shape = []
for arg_name, bi, l in zip(arg_names, arg_base_indices, shape):
if l > 1:
non1_arg_names.append(arg_name)
non1_arg_base_indices.append(bi)
non1_shape.append(l)
# }}}
# {{{ subtract off the base indices
# add the new, base-0 as new in dimensions
sp = global_footprint_map.get_space()
tgt_idx = sp.dim(dim_type.out)
n_args = len(unique_new_arg_names)
n_args = len(arg_names)
nn1_args = len(non1_arg_names)
aug_domain = global_footprint_map.move_dims(
dim_type.out, tgt_idx,
dim_type.in_, 0,
n_args).range().coalesce()
aug_domain = aug_domain.insert_dims(dim_type.set, tgt_idx, n_args)
for i, name in enumerate(unique_new_arg_names):
aug_domain = aug_domain.insert_dims(dim_type.set, tgt_idx, nn1_args)
for i, name in enumerate(non1_arg_names):
aug_domain = aug_domain.set_dim_name(dim_type.set, tgt_idx+i, name)
# index layout now:
# <....out.....> (tgt_idx) <base-0 args> <args>
# <....out.....> (tgt_idx) <base-0 non-1-length args> <args>
from loopy.symbolic import aff_from_expr
for uarg_name, bi in zip(unique_new_arg_names, base_indices):
cns = isl.Constraint.equality_from_aff(
aff_from_expr(aug_domain.get_space(),
var(uarg_name) - (var(uarg_name+"'") - bi)))
for arg_name, bi, s in zip(arg_names, arg_base_indices, shape):
if s > 1:
cns = isl.Constraint.equality_from_aff(
aff_from_expr(aug_domain.get_space(),
var(arg_name) - (var(arg_name+"'") - bi)))
aug_domain = aug_domain.add_constraint(cns)
# }}}
aug_domain = aug_domain.add_constraint(cns)
# eliminate inames with non-zero base indices
aug_domain = aug_domain.eliminate(dim_type.set, tgt_idx+n_args, n_args)
aug_domain = aug_domain.remove_dims(dim_type.set, tgt_idx+n_args, n_args)
aug_domain = aug_domain.eliminate(dim_type.set, tgt_idx+nn1_args, n_args)
aug_domain = aug_domain.remove_dims(dim_type.set, tgt_idx+nn1_args, n_args)
base_indices_2, shape_2 = find_var_base_indices_and_shape_from_inames(
aug_domain, unique_new_arg_names,
kernel.cache_manager)
aug_domain, non1_arg_names, kernel.cache_manager)
assert base_indices_2 == [0] * n_args
assert shape_2 == shape
assert base_indices_2 == [0] * nn1_args
assert shape_2 == non1_shape
return aug_domain, base_indices, shape
return (non1_arg_names, aug_domain,
arg_base_indices, non1_arg_base_indices, non1_shape)
......@@ -229,7 +251,7 @@ def precompute(kernel, subst_name, dtype, sweep_inames=[],
newly_created_var_names = set()
# {{{ make sure that new
# {{{ make sure that new arg names are unique
# (and substitute in subst_expressions if any variable name changes are necessary)
......@@ -252,12 +274,16 @@ def precompute(kernel, subst_name, dtype, sweep_inames=[],
if new_name is not None:
old_to_new[name] = var(new_name)
newly_created_var_names.add(new_name)
unique_new_arg_names.append(new_name)
new_arg_name_to_tag[new_name] = arg_name_to_tag[name]
newly_created_var_names.add(new_name)
else:
unique_new_arg_names.append(name)
new_arg_name_to_tag[name] = arg_name_to_tag[name]
newly_created_var_names.add(name)
old_arg_names = arg_names
arg_names = unique_new_arg_names
arg_name_to_tag = new_arg_name_to_tag
subst_expr = (
......@@ -269,10 +295,10 @@ def precompute(kernel, subst_name, dtype, sweep_inames=[],
# {{{ align and intersect the footprint and the domain
# (If there are independent inames, this adds extra dimensions to the domain.)
new_domain, target_var_base_indices, target_var_shape = \
get_footprint(kernel, subst_name, arg_names, unique_new_arg_names,
sweep_inames, invocation_descriptors)
(non1_arg_names, new_domain,
arg_base_indices, non1_arg_base_indices, non1_shape) = \
get_footprint(kernel, subst_name, old_arg_names, arg_names,
sweep_inames, invocation_descriptors)
new_domain = new_domain.coalesce()
if isinstance(new_domain, isl.Set):
......@@ -296,8 +322,8 @@ def precompute(kernel, subst_name, dtype, sweep_inames=[],
new_temporary_variables[target_var_name] = TemporaryVariable(
name=target_var_name,
dtype=np.dtype(dtype),
base_indices=target_var_base_indices,
shape=target_var_shape,
base_indices=(0,)*len(non1_shape),
shape=non1_shape,
is_local=None)
# }}}
......@@ -306,14 +332,27 @@ def precompute(kernel, subst_name, dtype, sweep_inames=[],
assignee = var(target_var_name)
if unique_new_arg_names:
assignee = assignee[tuple(var(iname) for iname in unique_new_arg_names)]
if non1_arg_names:
assignee = assignee[tuple(var(iname) for iname in non1_arg_names)]
def zero_length_1_arg(arg_name):
if arg_name in non1_arg_names:
return var(arg_name)
else:
return 0
compute_expr = (SubstitutionMapper(
make_subst_func(dict(
(arg_name, zero_length_1_arg(arg_name)+bi)
for arg_name, bi in zip(arg_names, arg_base_indices)
)))
(subst_expr))
from loopy.kernel import Instruction
compute_insn = Instruction(
id=kernel.make_unique_instruction_id(based_on=subst_name),
assignee=assignee,
expression=subst_expr)
expression=compute_expr)
# }}}
......@@ -330,7 +369,7 @@ def precompute(kernel, subst_name, dtype, sweep_inames=[],
return
args = [simplify_via_aff(new_domain.get_space(), arg-bi)
for arg, bi in zip(args, target_var_base_indices)]
for arg, bi in zip(args, non1_arg_base_indices)]
new_outer_expr = var(target_var_name)
if args:
......@@ -348,7 +387,8 @@ def precompute(kernel, subst_name, dtype, sweep_inames=[],
# }}}
new_iname_to_tag = kernel.iname_to_tag.copy()
new_iname_to_tag.update(arg_name_to_tag)
if sweep_inames:
new_iname_to_tag.update(arg_name_to_tag)
new_substs = dict(
(s.name, s.copy(expression=sub_map(subst.expression)))
......
......@@ -1112,12 +1112,13 @@ def find_var_base_indices_and_shape_from_inames(domain, inames, cache_manager):
lower_bound_pw_aff = cache_manager.dim_min(domain, iname_to_dim[iname][1])
upper_bound_pw_aff = cache_manager.dim_max(domain, iname_to_dim[iname][1])
from loopy.isl_helpers import static_max_of_pw_aff
from loopy.isl_helpers import static_max_of_pw_aff, static_value_of_pw_aff
from loopy.symbolic import pw_aff_to_expr
shape.append(pw_aff_to_expr(static_max_of_pw_aff(
upper_bound_pw_aff - lower_bound_pw_aff + 1, constants_only=True)))
base_indices.append(pw_aff_to_expr(lower_bound_pw_aff))
base_indices.append(pw_aff_to_expr(
static_value_of_pw_aff(lower_bound_pw_aff, constants_only=False)))
return base_indices, shape
......
......@@ -91,7 +91,6 @@ def extract_subst(kernel, subst_name, template, parameters):
lhs_mapping_candidates=set(parameters) | set(matching_vars))
def gather_exprs(expr, mapper):
print expr
urecs = unif(template, expr)
if urecs:
......
......@@ -85,7 +85,7 @@ class IdentityMapper(IdentityMapperBase, IdentityMapperMixin):
class WalkMapper(WalkMapperBase):
def map_reduction(self, expr):
self.rec(expr.expression)
self.rec(expr.expr)
class CallbackMapper(CallbackMapperBase, IdentityMapper):
map_reduction = CallbackMapperBase.map_constant
......@@ -330,6 +330,9 @@ class LoopyCCodeMapper(CCodeMapper):
var_subst_map.update(assignments)
return self.copy(var_subst_map=var_subst_map)
def map_common_subexpression(self, expr, prec):
raise RuntimeError("common subexpressions are not allowed in loopy")
def map_variable(self, expr, prec):
if expr.name in self.var_subst_map:
if self.with_annotation:
......
......@@ -219,8 +219,10 @@ def test_plain_matrix_mul(ctx_factory):
knl = lp.split_dimension(knl, "j", 16,
outer_tag="g.1", inner_tag="l.0")
knl = lp.split_dimension(knl, "k", 16)
knl = lp.add_prefetch(knl, 'a', ["k_inner", "i_inner"])
knl = lp.add_prefetch(knl, 'b', ["j_inner", "k_inner", ])
knl = lp.add_prefetch(knl, "a", ["k_inner", "i_inner"])
knl = lp.add_prefetch(knl, "b", ["j_inner", "k_inner", ])
print lp.preprocess_kernel(knl)
kernel_gen = lp.generate_loop_schedules(knl)
kernel_gen = lp.check_kernels(kernel_gen, {})
......@@ -319,8 +321,8 @@ def test_rank_one(ctx_factory):
name="rank_one", assumptions="n >= 16")
def variant_1(knl):
knl = lp.add_prefetch(knl, "a")
knl = lp.add_prefetch(knl, "b")
knl = lp.add_prefetch(knl, "a", [])
knl = lp.add_prefetch(knl, "b", [])
return knl
def variant_2(knl):
......@@ -329,8 +331,8 @@ def test_rank_one(ctx_factory):
knl = lp.split_dimension(knl, "j", 16,
outer_tag="g.1", inner_tag="l.1")
knl = lp.add_prefetch(knl, "a")
knl = lp.add_prefetch(knl, "b")
knl = lp.add_prefetch(knl, "a", [])
knl = lp.add_prefetch(knl, "b", [])
return knl
def variant_3(knl):
......@@ -360,15 +362,16 @@ def test_rank_one(ctx_factory):
knl = lp.split_dimension(knl, "j_inner", 16,
inner_tag="l.1")
knl = lp.split_dimension(knl, "j_inner_fetch_b", 16,
knl = lp.split_dimension(knl, "a_fetch_0", 16,
outer_tag="l.1", inner_tag="l.0")
knl = lp.split_dimension(knl, "i_inner_fetch_a", 16,
knl = lp.split_dimension(knl, "b_fetch_0", 16,
outer_tag="l.1", inner_tag="l.0")
return knl
seq_knl = knl
for variant in [variant_1, variant_2, variant_4]:
#for variant in [variant_1, variant_2, variant_4]:
for variant in [variant_2, variant_4]:
kernel_gen = lp.generate_loop_schedules(variant(knl))
kernel_gen = lp.check_kernels(kernel_gen, dict(n=n))
......
......@@ -67,14 +67,13 @@ def test_multi_cse(ctx_factory):
knl = lp.make_kernel(ctx.devices[0],
"{[i]: 0<=i<100}",
[
"[i] <float32> z[i] = cse(a[i]) + cse(a[i])**2"
"[i] <float32> z[i] = a[i] + a[i]**2"
],
[lp.ArrayArg("a", np.float32, shape=(100,))],
local_sizes={0: 16})
knl = lp.split_dimension(knl, "i", 16, inner_tag="l.0")
knl = lp.add_prefetch(knl, "a", [])
#knl = lp.realize_cse(knl, None, np.float32, ["i_inner"])
kernel_gen = lp.generate_loop_schedules(knl)
kernel_gen = lp.check_kernels(kernel_gen)
......@@ -86,38 +85,6 @@ def test_multi_cse(ctx_factory):
def test_bad_stencil(ctx_factory):
ctx = ctx_factory()
knl = lp.make_kernel(ctx.devices[0],
"{[i,j]: 0<= i,j < 32}",
[
"[i] <float32> z[i,j] = -2*cse(a[i,j])"
" + cse(a[i,j-1])"
" + cse(a[i,j+1])"
" + cse(a[i-1,j])"
" + cse(a[i+1,i])" # watch out: i!
],
[
lp.ArrayArg("a", np.float32, shape=(32,32,))
])
def variant_2(knl):
knl = lp.split_dimension(knl, "i", 16, outer_tag="g.0", inner_tag="l.0")
knl = lp.realize_cse(knl, None, np.float32, ["i_inner", "j"])
return knl
for variant in [variant_2]:
kernel_gen = lp.generate_loop_schedules(variant(knl),
loop_priority=["i_outer", "i_inner_0", "j_0"])
kernel_gen = lp.check_kernels(kernel_gen)
for knl in kernel_gen:
print lp.generate_code(knl)
def test_stencil(ctx_factory):
ctx = ctx_factory()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment