Newer
Older
args='var[k], k'),
ReductionTest('argmin',
lambda x: (
x[0] == np.min(var_int) and var_int[out[1]] == np.min(var_int)),
args='var[k], k')
]
for reduction, function, args in reductions:
kstr = ("out" if 'arg' not in reduction
else "out[0], out[1]")
kstr += ' = {0}(k, {1})'.format(reduction, args)
kstr,
[var_lp, '...'])
knl = lp.fix_parameters(knl, n=200)
Nick Curtis
committed
_, (out,) = knl(queue, out_host=True)
def test_nosync_option_parsing():
    """Check how ``nosync`` / ``nosync_query`` instruction options print.

    Each instruction uses a different spelling of the option. The
    assertions pin the ``no_sync_with=...`` text that ``str(knl)`` emits
    for each of them.
    """
    knl = lp.make_kernel(
        "{[i]: 0 <= i < 10}",
        """
<>t = 1 {id=insn1,nosync=insn1}
t = 2 {id=insn2,nosync=insn1:insn2}
t = 3 {id=insn3,nosync=insn1@local:insn2@global:insn3@any}
t = 4 {id=insn4,nosync_query=id:insn*@local}
t = 5 {id=insn5,nosync_query=id:insn1}
""",
        # Plain text output so the substring checks below are not broken
        # by color escape codes.
        options=lp.Options(allow_terminal_colors=False))

    printed = str(knl)

    expected_fragments = [
        "# insn1,no_sync_with=insn1@any",
        "# insn2,no_sync_with=insn1@any:insn2@any",
        "# insn3,no_sync_with=insn1@local:insn2@global:insn3@any",
        "# insn4,no_sync_with=insn1@local:insn2@local:insn3@local:insn5@local",  # noqa
        "# insn5,no_sync_with=insn1@any",
    ]
    for fragment in expected_fragments:
        assert fragment in printed
def assert_barrier_between(knl, id1, id2, ignore_barriers_in_levels=()):
    """Assert that the schedule of *knl* has a barrier between two instructions.

    Walks ``knl.schedule`` in order. After the instruction with id *id1*
    has been run, every ``Barrier`` item counts as separating the two
    instructions, unless the current loop nesting depth is in
    *ignore_barriers_in_levels*. When *id2* is reached, *id1* must already
    have been seen and at least one counted barrier must have occurred.

    :arg knl: a scheduled kernel (``knl.schedule`` must be populated).
    :arg id1: id of the instruction expected to run first.
    :arg id2: id of the instruction expected to run after the barrier.
    :arg ignore_barriers_in_levels: loop nesting depths at which barriers
        are not counted. Depth 0 means outside every loop, and each
        ``EnterLoop`` increases the depth by one.
    :raises AssertionError: if *id2* is reached before *id1*, or if no
        counted barrier lies between them.
    :raises RuntimeError: if *id2* never appears in the schedule.
    """
    from loopy.schedule import (RunInstruction, Barrier, EnterLoop, LeaveLoop)

    watch_for_barrier = False
    seen_barrier = False
    loop_level = 0

    for sched_item in knl.schedule:
        if isinstance(sched_item, RunInstruction):
            if sched_item.insn_id == id1:
                watch_for_barrier = True
            elif sched_item.insn_id == id2:
                assert watch_for_barrier, (
                    "instruction '{0}' was reached before '{1}' ran"
                    .format(id2, id1))
                assert seen_barrier, (
                    "no barrier found between '{0}' and '{1}'"
                    .format(id1, id2))
                return
        elif isinstance(sched_item, Barrier):
            # Only barriers after id1, and outside the ignored nesting
            # depths, count as separating the two instructions.
            if watch_for_barrier and loop_level not in ignore_barriers_in_levels:
                seen_barrier = True
        elif isinstance(sched_item, EnterLoop):
            loop_level += 1
        elif isinstance(sched_item, LeaveLoop):
            loop_level -= 1

    # Report the actual id; the previous message printed the literal "id2".
    raise RuntimeError("instruction '{0}' was not seen".format(id2))
def test_barrier_insertion_near_top_of_loop():
    """Barriers must separate consecutive local-memory accesses in a loop.

    ``i`` is mapped to ``l.0``, and ``a`` and ``b`` are made local
    temporaries. ``tcomp`` reads ``a`` at a neighbouring index
    ``(i + 1) % 10``. The test checks that the scheduler puts a barrier
    between each consecutive pair of instructions.
    """
    knl = lp.make_kernel(
        "{[i,j]: 0 <= i,j < 10 }",
        """
for i
<>a[i] = i {id=ainit}
for j
<>t = a[(i + 1) % 10] {id=tcomp}
<>b[i,j] = a[i] + t {id=bcomp1}
b[i,j] = b[i,j] + 1 {id=bcomp2}
end
end
""",
        seq_dependencies=True)

    knl = lp.tag_inames(knl, {"i": "l.0"})
    for temp_name in ("a", "b"):
        knl = lp.set_temporary_scope(knl, temp_name, "local")

    scheduled = lp.get_one_scheduled_kernel(lp.preprocess_kernel(knl))
    print(scheduled)

    expected_barrier_pairs = [
        ("ainit", "tcomp"),
        ("tcomp", "bcomp1"),
        ("bcomp1", "bcomp2"),
    ]
    for before, after in expected_barrier_pairs:
        assert_barrier_between(scheduled, before, after)
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
def test_barrier_insertion_near_bottom_of_loop():
    """Barriers must be placed correctly around a loop with a symbolic bound.

    The ``j`` loop runs up to the parameter ``jmax``. ``aupdate`` writes
    the local temporary ``a`` after that loop. The final check ignores
    barriers at nesting depth 1, so the barrier separating ``ainit`` from
    ``aupdate`` must come from a different depth.
    """
    knl = lp.make_kernel(
        ["{[i]: 0 <= i < 10 }",
         "[jmax] -> {[j]: 0 <= j < jmax}"],
        # NOTE(review): `t` is read in bcomp1 but never assigned in this
        # kernel, so loopy presumably infers it as a kernel argument.
        # Confirm this is intended.
        """
for i
<>a[i] = i {id=ainit}
for j
<>b[i,j] = a[i] + t {id=bcomp1}
b[i,j] = b[i,j] + 1 {id=bcomp2}
end
a[i] = i + 1 {id=aupdate}
end
""",
        seq_dependencies=True)

    knl = lp.tag_inames(knl, {"i": "l.0"})
    for temp_name in ("a", "b"):
        knl = lp.set_temporary_scope(knl, temp_name, "local")

    scheduled = lp.get_one_scheduled_kernel(lp.preprocess_kernel(knl))
    print(scheduled)

    assert_barrier_between(scheduled, "bcomp1", "bcomp2")
    assert_barrier_between(
        scheduled, "ainit", "aupdate", ignore_barriers_in_levels=[1])
def test_global_barrier_order_finding():
    """Check the global-barrier ordering and the most-recent-barrier lookups.

    The kernel has three explicit global barriers: ``top``, ``yoink`` and
    ``postloop``. The test checks their overall order, then checks which
    barrier ``lp.find_most_recent_global_barrier`` reports for a set of
    instructions. ``None`` means no barrier precedes the instruction.
    """
    knl = lp.make_kernel(
        "{[i,itrip]: 0<=i<n and 0<=itrip<ntrips}",
        """
for i
for itrip
... gbarrier {id=top}
<> z[i] = z[i+1] + z[i] {id=wr_z,dep=top}
<> v[i] = 11 {id=wr_v,dep=top}
... gbarrier {dep=wr_z:wr_v,id=yoink}
z[i] = z[i] - z[i+1] + v[i] {id=iupd, dep=yoink}
end
... nop {id=nop}
... gbarrier {dep=iupd,id=postloop}
z[i] = z[i] - z[i+1] + v[i] {id=zzzv,dep=postloop}
end
""")

    expected_order = ("top", "yoink", "postloop")
    assert lp.get_global_barrier_order(knl) == expected_order

    # (instruction id, id of the most recent preceding global barrier)
    most_recent_barrier = [
        ("nop", None),
        ("top", None),
        ("wr_z", "top"),
        ("wr_v", "top"),
        ("yoink", "top"),
        ("postloop", "yoink"),
        ("zzzv", "postloop"),
    ]
    for insn_id, expected_barrier in most_recent_barrier:
        found = lp.find_most_recent_global_barrier(knl, insn_id)
        assert found == expected_barrier
def test_global_barrier_error_if_unordered():
    """Asking for the order of unordered global barriers must raise.

    The kernel declares two global barriers with no dependency between
    them. The test expects ``lp.get_global_barrier_order`` to raise
    ``LoopyError``.
    """
    from loopy.diagnostic import LoopyError

    # FIXME: Should be illegal to declare this
    knl = lp.make_kernel("{[i]: 0 <= i < 10}",
        """
... gbarrier
... gbarrier
""")

    with pytest.raises(LoopyError):
        lp.get_global_barrier_order(knl)
def test_multi_argument_reduction_type_inference():
    """Type inference on a nested multi-argument reduction.

    Two segmented-sum reductions are nested: the inner one runs over ``j``
    with the tuple expression ``(1, 2)``, and the outer one runs over
    ``i``. The test checks that ``TypeInferenceMapper`` infers the tuple
    type ``[(int32, int32)]``.
    """
    from loopy.type_inference import TypeInferenceMapper
    from loopy.library.reduction import SegmentedSumReductionOperation
    from loopy.types import to_loopy_type

    op = SegmentedSumReductionOperation()
    knl = lp.make_kernel("{[i,j]: 0<=i<10 and 0<=j<i}", "")
    int32 = to_loopy_type(np.int32)

    inner_reduction = lp.symbolic.Reduction(
        operation=op,
        inames="j",
        expr=(1, 2),
        allow_simultaneous=True)
    outer_reduction = lp.symbolic.Reduction(
        operation=op,
        inames=("i",),
        expr=inner_reduction,
        allow_simultaneous=True)

    type_mapper = TypeInferenceMapper(knl)
    inferred = type_mapper(
        outer_reduction, return_tuple=True, return_dtype_set=True)
    assert inferred == [(int32, int32)]
def test_multi_argument_reduction_parsing():
    """A nested multi-argument ``reduce(argmax, ...)`` must parse as a Reduction.

    The check looks at the ``expr`` attribute of the parsed result, which
    is presumably the inner ``reduce(argmax, j, i, j)``. It must be a
    ``Reduction`` instance.
    """
    from loopy.symbolic import parse, Reduction

    parsed = parse("reduce(argmax, i, reduce(argmax, j, i, j))")
    assert isinstance(parsed.expr, Reduction)
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
def test_struct_assignment(ctx_factory):
    """Assign individual fields of a struct-typed global array in a kernel.

    A ``bbhit`` struct dtype is matched to its C layout and registered with
    pyopencl. Its C declaration is passed to the kernel as a preamble. The
    kernel writes each field of ``result[i]``. The test then checks every
    field on the host. Previously it ran the kernel without verifying
    anything.
    """
    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    bbhit = np.dtype([
        ("tmin", np.float32),
        ("tmax", np.float32),
        ("bi", np.int32),
        ("hit", np.int32)])

    # Match the numpy layout to what the device compiler will use, then
    # register the dtype so loopy/pyopencl can refer to it as "bbhit".
    bbhit, bbhit_c_decl = cl.tools.match_dtype_to_c_struct(
        ctx.devices[0], "bbhit", bbhit)
    bbhit = cl.tools.get_or_register_dtype('bbhit', bbhit)

    preamble = bbhit_c_decl

    n = 200

    knl = lp.make_kernel(
        "{ [i]: 0<=i<N }",
        """
for i
result[i].hit = i % 2
result[i].tmin = i
result[i].tmax = i+10
result[i].bi = i
end
""",
        [
            lp.GlobalArg("result", shape=("N",), dtype=bbhit),
            "..."],
        preambles=[("000", preamble)])
    knl = lp.set_options(knl, write_cl=True)

    _, (result,) = knl(queue, N=n, out_host=True)

    idx = np.arange(n)
    assert (result["hit"] == idx % 2).all()
    assert (result["tmin"] == idx).all()
    assert (result["tmax"] == idx + 10).all()
    assert (result["bi"] == idx).all()
def test_kernel_var_name_generator():
    """The kernel's variable-name generator must avoid some related names.

    The kernel defines ``a`` and ``b_s0``. The generator must not hand out
    ``a_s0`` or ``b``, even though neither is itself a kernel variable.
    Presumably this is because they share a base name with an existing
    variable.
    """
    knl = lp.make_kernel(
        "{[i]: 0 <= i <= 10}",
        """
<>a = 0
<>b_s0 = 0
""")

    name_gen = knl.get_var_name_generator()

    # Query order is kept from the original: the generator may record each
    # name it returns.
    for requested in ("a_s0", "b"):
        assert name_gen(requested) != requested
if __name__ == "__main__":
    if len(sys.argv) > 1:
        # Developer convenience: execute the given Python statement (for
        # example a call to a single test function) in this module's
        # namespace.
        # NOTE: exec() runs arbitrary code from the command line; only use
        # with trusted input.
        exec(sys.argv[1])
    else:
        # `py.test.cmdline` belongs to the legacy `py` library, which modern
        # pytest no longer depends on; `pytest.main` is the supported entry.
        import pytest
        pytest.main([__file__])