Skip to content
Snippets Groups Projects
Commit fba77761 authored by Andreas Klöckner
Browse files

Base slab decomposition on absolute loop bounds.

This removes the 'spurious' conditionals in axpy.
parent 8fad429b
No related branches found
No related tags found
No related merge requests found
......@@ -62,6 +62,8 @@ Things to consider
TODO
^^^^
- Make axpy better.
- implemented_domain may end up being smaller than requested in cse
evaluations--check that!
......@@ -88,6 +90,8 @@ TODO
Dealt with
^^^^^^^^^^
- Screwy lower bounds in slab decomposition
- reimplement add_prefetch
- Flag, exploit idempotence
......
......@@ -32,9 +32,6 @@ def get_simple_loop_bounds(kernel, sched_index, iname, implemented_domain):
# {{{ conditional-minimizing slab decomposition
def get_slab_decomposition(kernel, iname, sched_index, codegen_state):
from loopy.isl_helpers import block_shift_constraint, negate_constraint
ccm = codegen_state.c_code_mapper
space = kernel.space
tag = kernel.iname_to_tag.get(iname)
......@@ -47,35 +44,79 @@ def get_slab_decomposition(kernel, iname, sched_index, codegen_state):
iname_tp, iname_idx = kernel.iname_to_dim[iname]
slabs = []
if lower_incr:
slabs.append(("initial", isl.Set.universe(kernel.space)
.add_constraint(lb_cns_orig)
.add_constraint(ub_cns_orig)
.add_constraint(
negate_constraint(
block_shift_constraint(
lb_cns_orig, iname_tp, iname_idx, -lower_incr)))))
slabs.append(("bulk",
(isl.Set.universe(kernel.space)
.add_constraint(
block_shift_constraint(lb_cns_orig, iname_tp, iname_idx, -lower_incr))
.add_constraint(
block_shift_constraint(ub_cns_orig, iname_tp, iname_idx, -upper_incr)))))
if upper_incr:
slabs.append(("final", isl.Set.universe(kernel.space)
.add_constraint(ub_cns_orig)
.add_constraint(lb_cns_orig)
.add_constraint(
negate_constraint(
block_shift_constraint(
ub_cns_orig, iname_tp, iname_idx, -upper_incr)))))
constraints = [lb_cns_orig]
if lower_incr or upper_incr:
bounds = kernel.get_iname_bounds(iname)
lower_bound_pw_aff_pieces = bounds.lower_bound_pw_aff.coalesce().get_pieces()
upper_bound_pw_aff_pieces = bounds.upper_bound_pw_aff.coalesce().get_pieces()
if len(lower_bound_pw_aff_pieces) > 1:
raise NotImplementedError("lower bound for slab decomp of '%s' needs "
"conditional/has more than one piece" % iname)
if len(upper_bound_pw_aff_pieces) > 1:
raise NotImplementedError("upper bound for slab decomp of '%s' needs "
"conditional/has more than one piece" % iname)
(_, lower_bound_aff), = lower_bound_pw_aff_pieces
(_, upper_bound_aff), = upper_bound_pw_aff_pieces
lower_bulk_bound = lb_cns_orig
upper_bulk_bound = lb_cns_orig
from loopy.isl_helpers import iname_rel_aff
if lower_incr:
assert lower_incr > 0
lower_slab = ("initial", isl.Set.universe(kernel.space)
.add_constraint(lb_cns_orig)
.add_constraint(ub_cns_orig)
.add_constraint(
isl.Constraint.inequality_from_aff(
iname_rel_aff(kernel.space,
iname, "<", lower_bound_aff+lower_incr))))
lower_bulk_bound = (
isl.Constraint.inequality_from_aff(
iname_rel_aff(kernel.space,
iname, ">=", lower_bound_aff+lower_incr)))
else:
lower_slab = None
if upper_incr:
assert upper_incr > 0
upper_slab = ("final", isl.Set.universe(kernel.space)
.add_constraint(lb_cns_orig)
.add_constraint(ub_cns_orig)
.add_constraint(
isl.Constraint.inequality_from_aff(
iname_rel_aff(kernel.space,
iname, ">=", upper_bound_aff-upper_incr))))
upper_bulk_bound = (
isl.Constraint.inequality_from_aff(
iname_rel_aff(kernel.space,
iname, "<", upper_bound_aff-upper_incr)))
else:
lower_slab = None
slabs = []
if lower_slab:
slabs.append(lower_slab)
slabs.append((
("bulk",
(isl.Set.universe(kernel.space)
.add_constraint(lower_bulk_bound)
.add_constraint(upper_bulk_bound)))))
if upper_slab:
slabs.append(upper_slab)
return slabs
# }}}
return lb_cns_orig, ub_cns_orig, slabs
else:
return [("bulk",
(isl.Set.universe(kernel.space)
.add_constraint(lb_cns_orig)
.add_constraint(ub_cns_orig)))]
# }}}
......@@ -166,7 +207,7 @@ def set_up_hw_parallel_loops(kernel, sched_index, codegen_state, hw_inames_left=
# }}}
lb_cns_orig, ub_cns_orig, slabs = get_slab_decomposition(
slabs = get_slab_decomposition(
kernel, iname, sched_index, codegen_state)
if other_inames_with_same_tag and len(slabs) > 1:
......@@ -202,7 +243,7 @@ def generate_sequential_loop_dim_code(kernel, sched_index, codegen_state):
iname = kernel.schedule[sched_index].iname
tag = kernel.iname_to_tag.get(iname)
lb_cns_orig, ub_cns_orig, slabs = get_slab_decomposition(
slabs = get_slab_decomposition(
kernel, iname, sched_index, codegen_state)
result = []
......
......@@ -109,19 +109,19 @@ def test_axpy(ctx_factory):
lp.ArrayArg("z", dtype, shape="n,"),
lp.ScalarArg("n", np.int32, approximately=n),
],
name="matmul")
name="matmul", assumptions="n>=4096")
unroll = 4
block_size = 256
knl = lp.split_dimension(knl, "i", unroll*block_size, outer_tag="g.0", slabs=(0, -1))
knl = lp.split_dimension(knl, "i_inner", block_size, outer_tag="unr", inner_tag="l.0", slabs=(0, -1))
knl = lp.split_dimension(knl, "i", unroll*block_size, outer_tag="g.0", slabs=(0, 1))
knl = lp.split_dimension(knl, "i_inner", block_size, outer_tag="unr", inner_tag="l.0")
kernel_gen = lp.generate_loop_schedules(knl)
kernel_gen = lp.check_kernels(kernel_gen, dict(n=n), kill_level_min=5)
a = cl_random.rand(queue, n, dtype=dtype)
b = cl_random.rand(queue, n, dtype=dtype)
c = cl_array.empty_like(a)
c = cl_array.zeros_like(a)
refsol = (2*a+3*b).get()
def launcher(kernel, gsize, lsize, check):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment