knl = lp.fix_parameters(knl, n=200)
@pytest.mark.parametrize("src_order", ["C"])
@pytest.mark.parametrize("tmp_order", ["C", "F"])
def test_temp_initializer(ctx_factory, src_order, tmp_order):
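# an initialized, read-only private temporary should round-trip its data
# for either memory order of the temporary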
a = np.random.randn(3, 3).copy(order=src_order)
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
knl = lp.make_kernel(
"{[i,j]: 0<=i,j<n}",
"out[i,j] = tmp[i,j]",
[
lp.TemporaryVariable("tmp",
initializer=a,
shape=lp.auto,
scope=lp.temp_var_scope.PRIVATE,
read_only=True,
order=tmp_order),
"..."
])
knl = lp.set_options(knl, write_cl=True, highlight_cl=True)
knl = lp.fix_parameters(knl, n=a.shape[0])
evt, (a2,) = knl(queue, out_host=True)
assert np.array_equal(a, a2)
def test_const_temp_with_initializer_not_saved():
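# a read-only temporary with an initializer needs no save/reload slot
# across the global barrier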
knl = lp.make_kernel(
"{[i]: 0<=i<10}",
"""
... gbarrier
out[i] = tmp[i]
""",
[
lp.TemporaryVariable("tmp",
initializer=np.arange(10),
shape=lp.auto,
scope=lp.temp_var_scope.PRIVATE,
read_only=True),
"..."
],
seq_dependencies=True)
knl = lp.preprocess_kernel(knl)
knl = lp.get_one_scheduled_kernel(knl)
knl = lp.save_and_reload_temporaries(knl)
# This ensures no save slot was added.
assert len(knl.temporary_variables) == 1
def test_header_extract():
knl = lp.make_kernel('{[k]: 0<=k<n}',
"""
for k
T[k] = k**2
end
""",
[lp.GlobalArg('T', shape=(200,), dtype=np.float32),
'...'])
knl = lp.fix_parameters(knl, n=200)
# test C
cknl = knl.copy(target=lp.CTarget())
assert str(lp.generate_header(cknl)[0]) == (
'void loopy_kernel(float *__restrict__ T);')
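# test CUDA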
cuknl = knl.copy(target=lp.CudaTarget())
assert str(lp.generate_header(cuknl)[0]) == (
'extern "C" __global__ void __launch_bounds__(1) '
'loopy_kernel(float *__restrict__ T);')
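# test OpenCL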
oclknl = knl.copy(target=lp.PyOpenCLTarget())
assert str(lp.generate_header(oclknl)[0]) == (
'__kernel void __attribute__ ((reqd_work_group_size(1, 1, 1))) '
'loopy_kernel(__global float *__restrict__ T);')
def test_scalars_with_base_storage(ctx_factory):
""" Regression test for !50 """
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
knl = lp.make_kernel(
"{ [i]: 0<=i<1}",
"a = 1",
[lp.TemporaryVariable("a", dtype=np.float64,
shape=(), base_storage="base")])
knl(queue, out_host=True)
def test_if_else(ctx_factory):
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
knl = lp.make_kernel(
"{ [i]: 0<=i<50}",
"""
if i % 3 == 0
a[i] = 15 {nosync_query=writes:a}
elif i % 3 == 1
a[i] = 11 {nosync_query=writes:a}
else
a[i] = 3 {nosync_query=writes:a}
end
"""
)
evt, (out,) = knl(queue, out_host=True)
out_ref = np.empty(50)
out_ref[::3] = 15
out_ref[1::3] = 11
out_ref[2::3] = 3
assert np.array_equal(out_ref, out)
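# second kernel: nested conditionals with else branches at both levels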
knl = lp.make_kernel(
"{ [i]: 0<=i<50}",
"""
for i
if i % 2 == 0
if i % 3 == 0
a[i] = 15 {nosync_query=writes:a}
elif i % 3 == 1
a[i] = 11 {nosync_query=writes:a}
else
a[i] = 3 {nosync_query=writes:a}
end
else
a[i] = 4 {nosync_query=writes:a}
end
end
"""
)
evt, (out,) = knl(queue, out_host=True)
out_ref = np.zeros(50)
out_ref[1::2] = 4
out_ref[0::6] = 15
out_ref[4::6] = 11
out_ref[2::6] = 3
assert np.array_equal(out_ref, out)
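# third kernel: conditionals nested inside a second loop, with else branches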
knl = lp.make_kernel(
"{ [i,j]: 0<=i,j<50}",
"""
for i
if i < 25
for j
if j % 2 == 0
a[i, j] = 1 {nosync_query=writes:a}
else
a[i, j] = 0 {nosync_query=writes:a}
end
end
else
for j
if j % 2 == 0
a[i, j] = 0 {nosync_query=writes:a}
else
a[i, j] = 1 {nosync_query=writes:a}
end
end
end
end
"""
)
evt, (out,) = knl(queue, out_host=True)
out_ref = np.zeros((50, 50))
out_ref[:25, 0::2] = 1
out_ref[25:, 1::2] = 1
assert np.array_equal(out_ref, out)
def test_tight_loop_bounds(ctx_factory):
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
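# the j domain depends on i; after splitting i, the generated code must
# still write exactly out[0..9]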
knl = lp.make_kernel(
["{ [i] : 0 <= i <= 5 }",
"[i] -> { [j] : 2 * i - 2 < j <= 2 * i and 0 <= j <= 9 }"],
"""
for i
for j
out[j] = j
end
end
""",
silenced_warnings="write_race(insn)")
knl = lp.split_iname(knl, "i", 5, inner_tag="l.0", outer_tag="g.0")
evt, (out,) = knl(queue, out_host=True)
assert (out == np.arange(10)).all()
def test_tight_loop_bounds_codegen():
knl = lp.make_kernel(
["{ [i] : 0 <= i <= 5 }",
"[i] -> { [j] : 2 * i - 2 <= j <= 2 * i and 0 <= j <= 9 }"],
"""
for i
for j
out[j] = j
end
end
""",
silenced_warnings="write_race(insn)",
target=lp.OpenCLTarget())
knl = lp.split_iname(knl, "i", 5, inner_tag="l.0", outer_tag="g.0")
cgr = lp.generate_code_v2(knl)
#print(cgr.device_code())
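# the generated j loop must clamp the i-dependent bounds against 0 <= j <= 9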
for_loop = \
"(gid(0) == 0 && lid(0) == 0 ? 0 : -2 + 2 * lid(0) + 10 * gid(0)); " \
"j <= (-1 + gid(0) == 0 && lid(0) == 0 ? 9 : 2 * lid(0)); ++j)"
assert for_loop in cgr.device_code()
def test_unscheduled_insn_detection():
knl = lp.make_kernel(
"{ [i]: 0 <= i < 10 }",
"""
out[i] = i {id=insn1}
""",
"...")
knl = lp.get_one_scheduled_kernel(lp.preprocess_kernel(knl))
insn1, = lp.find_instructions(knl, "id:insn1")
knl.instructions.append(insn1.copy(id="insn2"))
from loopy.diagnostic import UnscheduledInstructionError
with pytest.raises(UnscheduledInstructionError):
lp.generate_code(knl)
def test_integer_reduction(ctx_factory):
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
from loopy.kernel.data import temp_var_scope as scopes
n = 200  # matches the lp.fix_parameters(knl, n=200) call below
vtype = np.int32  # assumed integer type for the test data
var_int = np.random.randint(1000, size=n).astype(vtype)
var_lp = lp.TemporaryVariable('var', initializer=var_int,
read_only=True,
scope=scopes.PRIVATE)
from collections import namedtuple
ReductionTest = namedtuple('ReductionTest', 'kind, check, args')
reductions = [
ReductionTest('max', lambda x: x == np.max(var_int), args='var[k]'),
ReductionTest('min', lambda x: x == np.min(var_int), args='var[k]'),
ReductionTest('sum', lambda x: x == np.sum(var_int), args='var[k]'),
ReductionTest('product', lambda x: x == np.prod(var_int), args='var[k]'),
ReductionTest('argmax',
lambda x: (
x[0] == np.max(var_int) and var_int[out[1]] == np.max(var_int)),
args='var[k], k'),
ReductionTest('argmin',
lambda x: (
x[0] == np.min(var_int) and var_int[out[1]] == np.min(var_int)),
args='var[k], k')
]
for reduction, function, args in reductions:
kstr = ("out" if 'arg' not in reduction
else "out[0], out[1]")
kstr += ' = {0}(k, {1})'.format(reduction, args)
knl = lp.make_kernel('{[k]: 0<=k<n}',
kstr,
[var_lp, '...'])
knl = lp.fix_parameters(knl, n=200)
_, (out,) = knl(queue, out_host=True)
assert function(out)
def test_complicated_argmin_reduction(ctx_factory):
cl_ctx = ctx_factory()
knl = lp.make_kernel(
"{[ictr,itgt,idim]: "
"0<=itgt<ntargets "
"and 0<=ictr<ncenters "
"and 0<=idim<ambient_dim}",
"""
for itgt
for ictr
<> dist_sq = sum(idim,
(tgt[idim,itgt] - center[idim,ictr])**2)
<> in_disk = dist_sq < (radius[ictr]*1.05)**2
<> matches = (
(in_disk
and qbx_forced_limit == 0)
or (in_disk
and qbx_forced_limit != 0
and qbx_forced_limit * center_side[ictr] > 0)
)
<> post_dist_sq = if(matches, dist_sq, HUGE)
end
<> min_dist_sq, <> min_ictr = argmin(ictr, ictr, post_dist_sq)
tgt_to_qbx_center[itgt] = if(min_dist_sq < HUGE, min_ictr, -1)
end
""")
knl = lp.fix_parameters(knl, ambient_dim=2)
knl = lp.add_and_infer_dtypes(knl, {
"tgt,center,radius,HUGE": np.float32,
"center_side,qbx_forced_limit": np.int32,
})
lp.auto_test_vs_ref(knl, cl_ctx, knl, parameters={
"HUGE": 1e20, "ncenters": 200, "ntargets": 300,
"qbx_forced_limit": 1})
def test_nosync_option_parsing():
knl = lp.make_kernel(
"{[i]: 0 <= i < 10}",
"""
<>t = 1 {id=insn1,nosync=insn1}
t = 2 {id=insn2,nosync=insn1:insn2}
t = 3 {id=insn3,nosync=insn1@local:insn2@global:insn3@any}
t = 4 {id=insn4,nosync_query=id:insn*@local}
t = 5 {id=insn5,nosync_query=id:insn1}
""",
options=lp.Options(allow_terminal_colors=False))
kernel_str = str(knl)
print(kernel_str)
assert "id=insn1, no_sync_with=insn1@any" in kernel_str
assert "id=insn2, no_sync_with=insn1@any:insn2@any" in kernel_str
assert "id=insn3, no_sync_with=insn1@local:insn2@global:insn3@any" in kernel_str
assert "id=insn4, no_sync_with=insn1@local:insn2@local:insn3@local:insn5@local" in kernel_str # noqa
assert "id=insn5, no_sync_with=insn1@any" in kernel_str
def assert_barrier_between(knl, id1, id2, ignore_barriers_in_levels=()):
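# Walk the linearized schedule and require at least one barrier between the
# runs of id1 and id2, not counting barriers at ignored loop nesting levels.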
from loopy.schedule import (RunInstruction, Barrier, EnterLoop, LeaveLoop)
watch_for_barrier = False
seen_barrier = False
loop_level = 0
for sched_item in knl.schedule:
if isinstance(sched_item, RunInstruction):
if sched_item.insn_id == id1:
watch_for_barrier = True
elif sched_item.insn_id == id2:
assert watch_for_barrier
assert seen_barrier
return
elif isinstance(sched_item, Barrier):
if watch_for_barrier and loop_level not in ignore_barriers_in_levels:
seen_barrier = True
elif isinstance(sched_item, EnterLoop):
loop_level += 1
elif isinstance(sched_item, LeaveLoop):
loop_level -= 1
raise RuntimeError("id2 was not seen")
def test_barrier_insertion_near_top_of_loop():
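# tcomp reads a[(i + 1) % 10], written by a different work item (i is l.0),
# so barriers are needed right at the top of the j loop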
knl = lp.make_kernel(
"{[i,j]: 0 <= i,j < 10 }",
"""
for i
<>a[i] = i {id=ainit}
for j
<>t = a[(i + 1) % 10] {id=tcomp}
<>b[i,j] = a[i] + t {id=bcomp1}
b[i,j] = b[i,j] + 1 {id=bcomp2}
end
end
""",
seq_dependencies=True)
knl = lp.tag_inames(knl, dict(i="l.0"))
knl = lp.set_temporary_scope(knl, "a", "local")
knl = lp.set_temporary_scope(knl, "b", "local")
knl = lp.get_one_scheduled_kernel(lp.preprocess_kernel(knl))
print(knl)
assert_barrier_between(knl, "ainit", "tcomp")
assert_barrier_between(knl, "tcomp", "bcomp1")
assert_barrier_between(knl, "bcomp1", "bcomp2")
def test_barrier_insertion_near_bottom_of_loop():
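# the barrier protecting the aupdate write must sit outside the j loop;
# barriers emitted inside the j loop (nesting level 1) are ignored below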
knl = lp.make_kernel(
["{[i]: 0 <= i < 10 }",
"[jmax] -> {[j]: 0 <= j < jmax}"],
"""
for i
<>a[i] = i {id=ainit}
for j
<>b[i,j] = a[i] + t {id=bcomp1}
b[i,j] = b[i,j] + 1 {id=bcomp2}
end
a[i] = i + 1 {id=aupdate}
end
""",
seq_dependencies=True)
knl = lp.tag_inames(knl, dict(i="l.0"))
knl = lp.set_temporary_scope(knl, "a", "local")
knl = lp.set_temporary_scope(knl, "b", "local")
knl = lp.get_one_scheduled_kernel(lp.preprocess_kernel(knl))
print(knl)
assert_barrier_between(knl, "bcomp1", "bcomp2")
assert_barrier_between(knl, "ainit", "aupdate", ignore_barriers_in_levels=[1])
def test_barrier_in_overridden_get_grid_size_expanded_kernel():
from loopy.kernel.data import temp_var_scope as scopes
# make simple barrier'd kernel
knl = lp.make_kernel('{[i]: 0 <= i < 10}',
"""
for i
a[i] = i {id=a}
... lbarrier {id=barrier}
b[i + 1] = a[i] {nosync=a}
end
""",
[lp.TemporaryVariable("a", np.float32, shape=(10,), order='C',
scope=scopes.LOCAL),
lp.GlobalArg("b", np.float32, shape=(11,), order='C')],
seq_dependencies=True)
# split into kernel w/ vecsize larger than iname domain
vecsize = 16
knl = lp.split_iname(knl, 'i', vecsize, inner_tag='l.0')
from testlib import GridOverride
# artificially expand via overridden_get_grid_sizes_for_insn_ids
knl = knl.copy(overridden_get_grid_sizes_for_insn_ids=GridOverride(
knl.copy(), vecsize))
# make sure we can generate the code
lp.generate_code_v2(knl)
def test_multi_argument_reduction_type_inference():
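# a nested segmented-sum reduction over an (int, int) tuple should infer
# a single (int32, int32) result type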
from loopy.type_inference import TypeInferenceMapper
from loopy.library.reduction import SegmentedSumReductionOperation
from loopy.types import to_loopy_type
op = SegmentedSumReductionOperation()
knl = lp.make_kernel("{[i,j]: 0<=i<10 and 0<=j<i}", "")
int32 = to_loopy_type(np.int32)
expr = lp.symbolic.Reduction(
operation=op,
inames=("i",),
expr=lp.symbolic.Reduction(
operation=op,
inames="j",
expr=(1, 2),
allow_simultaneous=True),
allow_simultaneous=True)
t_inf_mapper = TypeInferenceMapper(knl)
assert (
t_inf_mapper(expr, return_tuple=True, return_dtype_set=True)
== [(int32, int32)])
def test_multi_argument_reduction_parsing():
from loopy.symbolic import parse, Reduction
assert isinstance(
parse("reduce(argmax, i, reduce(argmax, j, i, j))").expr,
Reduction)
def test_global_barrier_order_finding():
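# global barriers must be reported in schedule order, and each instruction
# must map to the most recent global barrier it depends on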
knl = lp.make_kernel(
"{[i,itrip]: 0<=i<n and 0<=itrip<ntrips}",
"""
for i
for itrip
... gbarrier {id=top}
<> z[i] = z[i+1] + z[i] {id=wr_z,dep=top}
<> v[i] = 11 {id=wr_v,dep=top}
... gbarrier {dep=wr_z:wr_v,id=yoink}
z[i] = z[i] - z[i+1] + v[i] {id=iupd, dep=yoink}
end
... nop {id=nop}
... gbarrier {dep=iupd,id=postloop}
z[i] = z[i] - z[i+1] + v[i] {id=zzzv,dep=postloop}
end
""")
assert lp.get_global_barrier_order(knl) == ("top", "yoink", "postloop")
for insn, barrier in (
("nop", None),
("top", None),
("wr_z", "top"),
("wr_v", "top"),
("yoink", "top"),
("postloop", "yoink"),
("zzzv", "postloop")):
assert lp.find_most_recent_global_barrier(knl, insn) == barrier
def test_global_barrier_error_if_unordered():
# FIXME: Should be illegal to declare this
knl = lp.make_kernel("{[i]: 0 <= i < 10}",
"""
... gbarrier
... gbarrier
""")
from loopy.diagnostic import LoopyError
with pytest.raises(LoopyError):
lp.get_global_barrier_order(knl)
def test_struct_assignment(ctx_factory):
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
bbhit = np.dtype([
("tmin", np.float32),
("tmax", np.float32),
("bi", np.int32),
("hit", np.int32)])
bbhit, bbhit_c_decl = cl.tools.match_dtype_to_c_struct(
ctx.devices[0], "bbhit", bbhit)
bbhit = cl.tools.get_or_register_dtype('bbhit', bbhit)
preamble = bbhit_c_decl
knl = lp.make_kernel(
"{ [i]: 0<=i<N }",
"""
for i
result[i].hit = i % 2 {nosync_query=writes:result}
result[i].tmin = i {nosync_query=writes:result}
result[i].tmax = i+10 {nosync_query=writes:result}
result[i].bi = i {nosync_query=writes:result}
end
""",
[
lp.GlobalArg("result", shape=("N",), dtype=bbhit),
"..."],
preambles=[("000", preamble)])
knl = lp.set_options(knl, write_cl=True)
knl(queue, N=200)
def test_inames_conditional_generation(ctx_factory):
ctx = ctx_factory()
knl = lp.make_kernel(
"{[i,j,k]: 0 < k < i and 0 < j < 10 and 0 < i < 10}",
"""
for k
... gbarrier
<>tmp1 = 0
end
for j
... gbarrier
<>tmp2 = i
end
""",
"...",
seq_dependencies=True)
knl = lp.tag_inames(knl, dict(i="g.0"))
with cl.CommandQueue(ctx) as queue:
knl(queue)
def test_kernel_var_name_generator():
knl = lp.make_kernel(
"{[i]: 0 <= i <= 10}",
"""
<>a = 0
<>b_s0 = 0
""")
vng = knl.get_var_name_generator()
assert vng("a_s0") != "a_s0"
assert vng("b") != "b"
def test_fixed_parameters(ctx_factory):
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
knl = lp.make_kernel(
"[n] -> {[i]: 0 <= i < n}",
"""
<>tmp[i] = i {id=init}
tmp[0] = 0 {dep=init}
""",
fixed_parameters=dict(n=1))
knl(queue)
def test_parameter_inference():
knl = lp.make_kernel("{[i]: 0 <= i < n and i mod 2 = 0}", "")
assert knl.all_params() == set(["n"])
def test_execution_backend_can_cache_dtypes(ctx_factory):
# When the kernel is invoked, the execution backend uses it as a cache key
# for the type inference and scheduling cache. This tests to make sure that
# dtypes in the kernel can be cached, even though they may not have a
# target.
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
knl = lp.make_kernel("{[i]: 0 <= i < 10}", "<>tmp[i] = i")
knl = lp.add_dtypes(knl, dict(tmp=int))
knl(queue)
def test_wildcard_dep_matching():
knl = lp.make_kernel(
"{[i]: 0 <= i < 10}",
"""
<>a = 0 {id=insn1}
<>b = 0 {id=insn2,dep=insn?}
<>c = 0 {id=insn3,dep=insn*}
<>d = 0 {id=insn4,dep=insn[12]}
<>e = 0 {id=insn5,dep=insn[!1]}
""",
"...")
all_insns = set("insn%d" % i for i in range(1, 6))
assert knl.id_to_insn["insn1"].depends_on == set()
assert knl.id_to_insn["insn2"].depends_on == all_insns - set(["insn2"])
assert knl.id_to_insn["insn3"].depends_on == all_insns - set(["insn3"])
assert knl.id_to_insn["insn4"].depends_on == set(["insn1", "insn2"])
assert knl.id_to_insn["insn5"].depends_on == all_insns - set(["insn1", "insn5"])
def test_preamble_with_separate_temporaries(ctx_factory):
from loopy.kernel.data import temp_var_scope as scopes
# the 'indirect' function mangler and preamble generator come from testlib below
n = 10
# for each entry come up with a random number of data points
num_data = np.asarray(np.random.randint(2, 10, size=n), dtype=np.int32)
# turn into offsets
offsets = np.asarray(np.hstack(([0], np.cumsum(num_data))), dtype=np.int32)
# create lookup data
lookup = np.empty(0)
for i in num_data:
lookup = np.hstack((lookup, np.arange(i)))
lookup = np.asarray(lookup, dtype=np.int32)
# and create data array
data = np.random.rand(np.prod(num_data))
# make kernel
kernel = lp.make_kernel('{[i]: 0 <= i < n}',
"""
for i
<>ind = indirect(offsets[i], offsets[i + 1], 1)
out[i] = data[ind]
end
""",
[lp.GlobalArg('out', shape=('n',)),
lp.TemporaryVariable(
'offsets', shape=(offsets.size,), initializer=offsets, scope=scopes.GLOBAL,
read_only=True),
lp.GlobalArg('data', shape=(data.size,), dtype=np.float64)],
)
# fix params, and add manglers / preamble
from testlib import SeparateTemporariesPreambleTestHelper
preamble_with_sep_helper = SeparateTemporariesPreambleTestHelper(
func_name='indirect',
func_arg_dtypes=(np.int32, np.int32, np.int32),
func_result_dtypes=(np.int32,),
arr=lookup
)
kernel = lp.register_preamble_generators(
kernel, [preamble_with_sep_helper.preamble_gen])
kernel = lp.register_function_manglers(
kernel, [preamble_with_sep_helper.mangler])
print(lp.generate_code(kernel)[0])
# and call (functionality unimportant, more that it compiles)
ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)
# check that it actually performs the lookup correctly
assert np.allclose(kernel(
queue, data=data.flatten('C'))[1][0], data[offsets[:-1] + 1])
def test_arg_inference_for_predicates():
knl = lp.make_kernel("{[i]: 0 <= i < 10}",
"""
if incr[i]
a = a + 1
end
""")
assert "incr" in knl.arg_dict
assert knl.arg_dict["incr"].shape == (10,)
def test_relaxed_stride_checks(ctx_factory):
# Check that loopy is compatible with numpy's relaxed stride rules.
ctx = ctx_factory()
knl = lp.make_kernel("{[i,j]: 0 <= i <= n and 0 <= j <= m}",
"""
a[i] = sum(j, A[i,j] * b[j])
""")
with cl.CommandQueue(ctx) as queue:
# assumed inputs: a (1, 10) Fortran-ordered zero array has relaxed strides
# in numpy, and zero data makes the expected result zero
mat = np.zeros((1, 10), order="F")
b = np.zeros(10)
evt, (a,) = knl(queue, A=mat, b=b)
assert a == 0
def test_add_prefetch_works_in_lhs_index():
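# after prefetching a1_map, instruction assignees (including the LHS index)
# must no longer reference a1_map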
knl = lp.make_kernel(
"{ [n,k,l,k1,l1,k2,l2]: "
"start<=n<end and 0<=k,k1,k2<3 and 0<=l,l1,l2<2 }",
"""
for n
<> a1_tmp[k,l] = a1[a1_map[n, k],l]
a1_tmp[k1,l1] = a1_tmp[k1,l1] + 1
a1_out[a1_map[n,k2], l2] = a1_tmp[k2,l2]
end
""",
[
lp.GlobalArg("a1,a1_out", None, "ndofs,2"),
lp.GlobalArg("a1_map", None, "nelements,3"),
])
knl = lp.add_prefetch(knl, "a1_map", "k")
from loopy.symbolic import get_dependencies
for insn in knl.instructions:
assert "a1_map" not in get_dependencies(insn.assignees)
def test_check_for_variable_access_ordering():
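# the two writes touch overlapping elements of a with no ordering between
# them, so scheduling must fail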
knl = lp.make_kernel(
"{[i]: 0<=i<n}",
"""
a[i] = 12
a[i+1] = 13
""")
knl = lp.preprocess_kernel(knl)
from loopy.diagnostic import VariableAccessNotOrdered
with pytest.raises(VariableAccessNotOrdered):
lp.get_one_scheduled_kernel(knl)
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
else:
from py.test.cmdline import main
main([__file__])