Newer
Older
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
def test_indexof_vec(ctx_factory):
    """Build and run a kernel that calls ``indexof_vec`` on a vector-tagged axis.

    Only code generation and execution are exercised; the output is not
    compared against an expected value.
    """
    context = ctx_factory()
    cq = cl.CommandQueue(context)

    platform_name = context.devices[0].platform.name
    if platform_name.startswith("Portable"):
        # Accurate as of 2015-10-08
        pytest.skip("POCL miscompiles vector code")

    kernel = lp.make_kernel(
         ''' { [i,j,k]: 0<=i,j,k<4 } ''',
         ''' out[i,j,k] = indexof_vec(out[i,j,k])''')

    kernel = lp.tag_inames(kernel, {"i": "vec"})
    kernel = lp.tag_data_axes(kernel, "out", "vec,c,c")
    kernel = lp.set_options(kernel, write_cl=True)

    # NOTE(review): no result check. A former (commented-out) check compared
    # against np.arange(25), which cannot match the 4x4x4 output domain.
    evt, (out,) = kernel(cq)
Andreas Klöckner
committed
def test_is_expression_equal():
    """Symbolic equality must see through commutation and polynomial expansion."""
    from loopy.symbolic import is_expression_equal
    from pymbolic import var

    x, y = var("x"), var("y")

    equal_pairs = [
        (x+2, 2+x),
        ((x+2)**2, x**2 + 4*x + 4),
        ((x+y)**2, x**2 + 2*x*y + y**2),
    ]
    for lhs, rhs in equal_pairs:
        assert is_expression_equal(lhs, rhs)
@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
def test_atomic(ctx_factory, dtype):
    """Atomic accumulation into ``out[i%20]``, checked against the unsplit kernel."""
    context = ctx_factory()

    is_eight_bytes = np.dtype(dtype).itemsize == 8
    if (is_eight_bytes
            and "cl_khr_int64_base_atomics" not in context.devices[0].extensions):
        pytest.skip("64-bit atomics not supported on device")

    import pyopencl.version  # noqa
    if cl.version.VERSION < (2015, 2) and dtype == np.int64:
        pytest.skip("int64 RNG not supported in PyOpenCL < 2015.2")

    kernel = lp.make_kernel(
            "{ [i]: 0<=i<n }",
            "out[i%20] = out[i%20] + 2*a[i] {atomic}",
            [
                lp.GlobalArg("out", dtype, shape=lp.auto, for_atomic=True),
                lp.GlobalArg("a", dtype, shape=lp.auto),
                "..."
                ],
            assumptions="n>0")

    reference = kernel
    kernel = lp.split_iname(kernel, "i", 512)
    kernel = lp.split_iname(
            kernel, "i_inner", 128, outer_tag="unr", inner_tag="g.0")

    lp.auto_test_vs_ref(reference, context, kernel, parameters=dict(n=10000))
# Test of atomic accumulation through a for_atomic temporary ('temp').
# NOTE(review): this test is truncated in this copy of the file and does not
# parse as-is:
#  - the domain string and the opening triple quote of the instruction string
#    that should follow `knl = lp.make_kernel(` are missing;
#  - instructions carrying the ids "init" and "temp_sum" (named in
#    silenced_warnings) are not present;
#  - the `lp.TemporaryVariable('temp', ...` call is never closed (presumably
#    the imported `scopes` was used there -- TODO confirm);
#  - `n` and `vec_width` are used at the end but never defined;
#  - no assertion on `out` is visible.
# Restore from upstream before relying on this test.
@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
def test_atomic_load(ctx_factory, dtype):
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
from loopy.kernel.data import temp_var_scope as scopes
if (
np.dtype(dtype).itemsize == 8
and "cl_khr_int64_base_atomics" not in ctx.devices[0].extensions):
pytest.skip("64-bit atomics not supported on device")
import pyopencl.version # noqa
if (
cl.version.VERSION < (2015, 2)
and dtype == np.int64):
pytest.skip("int64 RNG not supported in PyOpenCL < 2015.2")
knl = lp.make_kernel(
<> upper = 0 {id=init_upper}
<> lower = 0 {id=init_lower}
upper = upper + i * a[i] {id=sum0,dep=init_upper}
lower = lower - b[i] {id=sum1,dep=init_lower}
""",
[
lp.GlobalArg("out", dtype, shape=lp.auto, for_atomic=True),
lp.GlobalArg("a", dtype, shape=lp.auto),
lp.GlobalArg("b", dtype, shape=lp.auto),
lp.TemporaryVariable('temp', dtype, for_atomic=True,
],
silenced_warnings=["write_race(init)", "write_race(temp_sum)"])
knl = lp.split_iname(knl, "j", vec_width, inner_tag="l.0")
_, out = knl(queue, a=np.arange(n, dtype=dtype), b=np.arange(n, dtype=dtype))
@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
def test_atomic_init(dtype):
    """Generate code for an atomic initialization ``out[i%4] = 0``.

    Only code generation is exercised: the kernel and its device code are
    printed, not executed.
    """
    vec_width = 4

    # Only 'i' is used by the instruction, so only 'i' is declared; a bounded
    # domain avoids an unconstrained, unused iname.
    knl = lp.make_kernel(
            "{ [i]: 0<=i<100 }",
            """
out[i%4] = 0 {id=init, atomic=init}
""",
            [
                lp.GlobalArg("out", dtype, shape=lp.auto, for_atomic=True),
                "..."
                ],
            silenced_warnings=["write_race(init)"])
    knl = lp.split_iname(knl, 'i', vec_width, inner_tag='l.0')

    print(knl)
    print(lp.generate_code_v2(knl).device_code())
# Regression test: a reduction over 'j' must not be placed inside iname 'i'.
# NOTE(review): the lp.CInstruction(...) call below is cut off after its
# second argument (no closing parenthesis is visible), so this block does not
# parse as-is. Presumably the missing argument sets the assignees (see the
# comment below) -- TODO restore from upstream.
def test_within_inames_and_reduction():
# See https://github.com/inducer/loopy/issues/24
# This is (purposefully) somewhat un-idiomatic, to replicate the conditions
# under which the above bug was found. If assignees were phi[i], then the
# iname propagation heuristic would not assume that dependent instructions
# need to run inside of 'i', and hence the within_inames* bits below would not
# be needed.
i1 = lp.CInstruction("i",
"doSomethingToGetPhi();",
from pymbolic.primitives import Subscript, Variable
i2 = lp.Assignment("a",
lp.Reduction("sum", "j", Subscript(Variable("phi"), Variable("j"))),
within_inames=frozenset(),
within_inames_is_final=True)
k = lp.make_kernel("{[i,j] : 0<=i,j<n}",
[i1, i2],
[
lp.GlobalArg("a", dtype=np.float32, shape=()),
lp.ValueArg("n", dtype=np.int32),
lp.TemporaryVariable("phi", dtype=np.float32, shape=("n",)),
],
target=lp.CTarget(),
)
k = lp.preprocess_kernel(k)
# "insn_0_j_update" is presumably the update instruction generated for the
# sum over 'j'; it must not have been nested inside 'i'.
assert 'i' not in k.insn_inames("insn_0_j_update")
print(k.stringify(with_dependencies=True))
def test_literal_local_barrier(ctx_factory):
    """An explicit ``... lbarrier`` instruction must generate and run.

    The kernel is compared against itself, so this mainly checks that code
    generation and execution succeed.
    """
    context = ctx_factory()

    kernel = lp.fix_parameters(
        lp.make_kernel(
            "{ [i]: 0<=i<n }",
            """
for i
... lbarrier
end
""", seq_dependencies=True),
        n=128)

    lp.auto_test_vs_ref(kernel, context, kernel, parameters=dict(n=5))
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
def test_local_barrier_mem_kind():
    """``lbarrier`` must emit the memory fence selected by its ``mem_kind``."""
    def _check_fence(mem_kind, expected_fence):
        instruction = '... lbarrier'
        if mem_kind:
            instruction += '{mem_kind=%s}' % mem_kind

        kernel = lp.make_kernel(
            "{ [i]: 0<=i<n }",
            """
for i
%s
end
""" % instruction, seq_dependencies=True,
            target=lp.PyOpenCLTarget())

        device_code = lp.generate_code_v2(kernel).device_code()
        assert 'barrier(%s)' % expected_fence in device_code

    # With no mem_kind given, a local-memory fence is expected.
    cases = [
        ('', 'CLK_LOCAL_MEM_FENCE'),
        ('global', 'CLK_GLOBAL_MEM_FENCE'),
        ('local', 'CLK_LOCAL_MEM_FENCE'),
    ]
    for mem_kind, expected_fence in cases:
        _check_fence(mem_kind, expected_fence)
def test_kernel_splitting(ctx_factory):
    """A global barrier must split the kernel into two device programs.

    The split, scheduled kernel is also checked numerically against the
    unsplit reference.
    """
    ctx = ctx_factory()

    knl = lp.make_kernel(
            "{ [i]: 0<=i<n }",
            """
for i
c[i] = a[i + 1]
... gbarrier
out[i] = c[i]
end
""", seq_dependencies=True)

    knl = lp.add_and_infer_dtypes(knl,
            {"a": np.float32, "c": np.float32, "out": np.float32, "n": np.int32})

    ref_knl = knl

    knl = lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0")

    # schedule
    from loopy.preprocess import preprocess_kernel
    knl = preprocess_kernel(knl)

    from loopy.schedule import get_one_scheduled_kernel
    knl = get_one_scheduled_kernel(knl)

    # map schedule onto host or device
    print(knl)

    cgr = lp.generate_code_v2(knl)

    # One device program on each side of the global barrier is expected.
    assert len(cgr.device_programs) == 2

    print(cgr.device_code())
    print(cgr.host_code())

    lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5))
def test_kernel_splitting_with_loop(ctx_factory):
    """Global barriers inside a loop over ``k`` must yield two device programs.

    The split, scheduled kernel is checked against the unsplit reference.
    """
    # Both names were previously used without being defined (NameError).
    ctx = ctx_factory()

    knl = lp.make_kernel(
            "{ [i,k]: 0<=i<n and 0<=k<3 }",
            """
for i, k
... gbarrier
c[k,i] = a[k, i + 1]
... gbarrier
out[k,i] = c[k,i]
end
""", seq_dependencies=True)

    knl = lp.add_and_infer_dtypes(knl,
            {"a": np.float32, "c": np.float32, "out": np.float32, "n": np.int32})

    # Keep the untransformed kernel as the reference for auto_test_vs_ref.
    ref_knl = knl

    knl = lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0")

    # schedule
    from loopy.preprocess import preprocess_kernel
    knl = preprocess_kernel(knl)

    from loopy.schedule import get_one_scheduled_kernel
    knl = get_one_scheduled_kernel(knl)

    # map schedule onto host or device
    print(knl)

    cgr = lp.generate_code_v2(knl)

    assert len(cgr.device_programs) == 2

    print(cgr.device_code())
    print(cgr.host_code())

    lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5))
def save_and_reload_temporaries_test(queue, knl, out_expect, debug=False):
    """Schedule *knl*, insert save/reload of temporaries, run it and check output.

    :arg queue: command queue on which the kernel is executed.
    :arg knl: kernel whose temporaries must survive global barriers.
    :arg out_expect: expected value of the kernel's single output. It is
        compared with ``==`` and then ``.all()``, so broadcasting applies and
        a scalar is accepted.
    :arg debug: if true, print the scheduled kernel and the generated
        device/host code, then abort instead of running the kernel.
    :raises RuntimeError: when *debug* is true, after printing.
    """
    from loopy.preprocess import preprocess_kernel
    from loopy.schedule import get_one_scheduled_kernel

    knl = preprocess_kernel(knl)
    knl = get_one_scheduled_kernel(knl)

    from loopy.transform.save import save_and_reload_temporaries
    knl = save_and_reload_temporaries(knl)
    # Saving/reloading adds instructions, so the kernel is scheduled again.
    knl = get_one_scheduled_kernel(knl)

    if debug:
        print(knl)
        cgr = lp.generate_code_v2(knl)
        print(cgr.device_code())
        print(cgr.host_code())
        # Deliberately stop here so the printed code can be inspected
        # (previously done with a bare `1/0`).
        raise RuntimeError(
            "debug mode: aborting after printing the generated code")

    _, (out,) = knl(queue, out_host=True)
    assert (out == out_expect).all(), (out, out_expect)
# Save/reload of a private scalar temporary `t` across a global barrier.
# NOTE(review): this block is incomplete in this copy of the file and does not
# parse as-is: `ctx`/`queue` are never created, and the `knl = lp.make_kernel(`
# line that should precede the domain string is missing, so `queue` and `knl`
# are undefined where they are used. TODO restore from upstream.
@pytest.mark.parametrize("hw_loop", [True, False])
def test_save_of_private_scalar(ctx_factory, hw_loop, debug=False):
"{ [i]: 0<=i<8 }",
"""
for i
<>t = i
... gbarrier
out[i] = t
end
""", seq_dependencies=True)
if hw_loop:
knl = lp.tag_inames(knl, dict(i="g.0"))
save_and_reload_temporaries_test(queue, knl, np.arange(8), debug)
# NOTE(review): orphaned fragment. The lines below look like the body of a
# separate test (a private *array* `t[i]`), whose `def` line, ctx/queue setup
# and `knl = lp.make_kernel(` opener are missing. TODO restore.
"{ [i]: 0<=i<8 }",
"""
for i
<>t[i] = i
... gbarrier
out[i] = t[i]
end
""", seq_dependencies=True)
knl = lp.set_temporary_scope(knl, "t", "private")
save_and_reload_temporaries_test(queue, knl, np.arange(8), debug)
# Save/reload of a private array `t[j]` inside a loop over 'i' tagged "g.0".
# NOTE(review): incomplete in this copy and does not parse as-is. The
# `knl = lp.make_kernel(` line before the domain string is missing, and the
# instruction string is never terminated: no closing `end` lines and no
# closing triple quote/parenthesis are visible. TODO restore from upstream.
def test_save_of_private_array_in_hw_loop(ctx_factory, debug=False):
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
"{ [i,j,k]: 0<=i,j,k<8 }",
"""
for i
for j
<>t[j] = j
... gbarrier
for k
out[i,k] = t[k]
knl = lp.tag_inames(knl, dict(i="g.0"))
knl = lp.set_temporary_scope(knl, "t", "private")
save_and_reload_temporaries_test(
queue, knl, np.vstack((8 * (np.arange(8),))), debug)
def test_save_of_private_multidim_array(ctx_factory, debug=False):
    """Save/reload of a private 2D temporary across a global barrier."""
    cq = cl.CommandQueue(ctx_factory())

    kernel = lp.make_kernel(
        "{ [i,j,k,l,m]: 0<=i,j,k,l,m<8 }",
        """
for i
for j, k
<>t[j,k] = k
end
... gbarrier
for l, m
out[i,l,m] = t[l,m]
end
end
""", seq_dependencies=True)
    kernel = lp.set_temporary_scope(kernel, "t", "private")

    # out[i, l, m] == m for every i and l.
    expected = np.tile(np.arange(8), (8, 8, 1))
    save_and_reload_temporaries_test(cq, kernel, expected, debug)
def test_save_of_private_multidim_array_in_hw_loop(ctx_factory, debug=False):
    """Like the private 2D save/reload case, but with 'i' tagged as ``g.0``."""
    cq = cl.CommandQueue(ctx_factory())

    kernel = lp.make_kernel(
        "{ [i,j,k,l,m]: 0<=i,j,k,l,m<8 }",
        """
for i
for j, k
<>t[j,k] = k
end
... gbarrier
for l, m
out[i,l,m] = t[l,m]
end
end
""", seq_dependencies=True)
    kernel = lp.set_temporary_scope(kernel, "t", "private")
    kernel = lp.tag_inames(kernel, dict(i="g.0"))

    # out[i, l, m] == m for every i and l.
    expected = np.tile(np.arange(8), (8, 8, 1))
    save_and_reload_temporaries_test(cq, kernel, expected, debug)
# Save/reload of several private temporaries (t_arr, t_scalar, flag) across
# global barriers, with 'i' optionally tagged "g.0".
# NOTE(review): incomplete in this copy and does not parse as-is:
#  - `ctx`/`queue` are never created;
#  - the `knl = lp.make_kernel(` line, its domain string and the opening
#    triple quote are missing;
#  - the instruction text opens three `for` loops but only two `end`s are
#    visible.
# TODO restore from upstream.
@pytest.mark.parametrize("hw_loop", [True, False])
def test_save_of_multiple_private_temporaries(ctx_factory, hw_loop, debug=False):
for i
for k
<> t_arr[k] = k
end
<> t_scalar = 1
for j
... gbarrier
out[j] = t_scalar
... gbarrier
t_scalar = 10
<> flag = i == 9
out[i] = t_arr[i] {if=flag}
end
""", seq_dependencies=True)
knl = lp.set_temporary_scope(knl, "t_arr", "private")
if hw_loop:
knl = lp.tag_inames(knl, dict(i="g.0"))
result = np.array([1, 10, 10, 10, 10, 10, 10, 10, 10, 9])
save_and_reload_temporaries_test(queue, knl, result, debug)
# NOTE(review): these statements have no enclosing `def` line in this copy of
# the file. They use `ctx_factory` and `debug`, so they presumably form the
# body of a test taking `(ctx_factory, debug=False)` that checks save/reload
# of a local-scope temporary. TODO restore the missing `def` line.
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
knl = lp.make_kernel(
"{ [i,j]: 0<=i,j<8 }",
"""
for i, j
<>t[2*j] = j
t[2*j+1] = j
... gbarrier
out[i] = t[2*i]
end
""", seq_dependencies=True)
knl = lp.set_temporary_scope(knl, "t", "local")
knl = lp.tag_inames(knl, dict(i="g.0", j="l.0"))
save_and_reload_temporaries_test(queue, knl, np.arange(8), debug)
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
def test_save_of_local_array_with_explicit_local_barrier(ctx_factory, debug=False):
    """Save/reload of a local array whose odd slots are filled after an lbarrier."""
    cq = cl.CommandQueue(ctx_factory())

    kernel = lp.make_kernel(
        "{ [i,j]: 0<=i,j<8 }",
        """
for i, j
<>t[2*j] = j
... lbarrier
t[2*j+1] = t[2*j]
... gbarrier
out[i] = t[2*i]
end
""", seq_dependencies=True)

    kernel = lp.set_temporary_scope(kernel, "t", "local")
    kernel = lp.tag_inames(kernel, dict(i="g.0", j="l.0"))

    save_and_reload_temporaries_test(cq, kernel, np.arange(8), debug)
# Save/reload of a multidimensional local temporary `t_local`; the expected
# output is the scalar 1, which is compared elementwise.
# NOTE(review): most of this test is missing in this copy. There is no
# ctx/queue setup and no `knl = lp.make_kernel(` line, and the instruction
# text between the domain string and the trailing `end` is absent. It does not
# parse as-is. TODO restore from upstream.
def test_save_local_multidim_array(ctx_factory, debug=False):
"{ [i,j,k]: 0<=i<2 and 0<=k<3 and 0<=j<2}",
end
""", seq_dependencies=True)
knl = lp.set_temporary_scope(knl, "t_local", "local")
knl = lp.tag_inames(knl, dict(j="l.0", i="g.0"))
save_and_reload_temporaries_test(queue, knl, 1, debug)
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
def test_save_with_base_storage(ctx_factory, debug=False):
    """Save/reload of two local temporaries aliased onto a shared base storage."""
    cq = cl.CommandQueue(ctx_factory())

    kernel = lp.make_kernel(
        "{[i]: 0 <= i < 10}",
        """
<>a[i] = 0
<>b[i] = i
... gbarrier
out[i] = a[i]
""",
        "...",
        seq_dependencies=True)

    kernel = lp.tag_inames(kernel, dict(i="l.0"))
    for name in ("a", "b"):
        kernel = lp.set_temporary_scope(kernel, name, "local")
    kernel = lp.alias_temporaries(kernel, ["a", "b"],
            synchronize_for_exclusive_use=False)

    save_and_reload_temporaries_test(cq, kernel, np.arange(10), debug)
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
def test_save_ambiguous_storage_requirements():
    """``save_and_reload_temporaries`` must raise ``LoopyError`` for this kernel.

    The write to ``out`` uses a duplicate of ``j`` that carries the same
    ``l.0`` tag as the ``j`` used to write the local temporary ``a``.
    """
    from loopy.diagnostic import LoopyError

    kernel = lp.make_kernel(
        "{[i,j]: 0 <= i < 10 and 0 <= j < 10}",
        """
<>a[j] = j
... gbarrier
out[i,j] = a[j]
""",
        seq_dependencies=True)

    kernel = lp.tag_inames(kernel, dict(i="g.0", j="l.0"))
    kernel = lp.duplicate_inames(
            kernel, "j", within="writes:out", tags={"j": "l.0"})
    kernel = lp.set_temporary_scope(kernel, "a", "local")

    kernel = lp.get_one_scheduled_kernel(lp.preprocess_kernel(kernel))

    with pytest.raises(LoopyError):
        lp.save_and_reload_temporaries(kernel)
def test_save_across_inames_with_same_tag(ctx_factory, debug=False):
    """Save/reload when the reader of ``a`` uses a duplicate ``l.0`` iname."""
    cq = cl.CommandQueue(ctx_factory())

    kernel = lp.make_kernel(
        "{[i]: 0 <= i < 10}",
        """
<>a[i] = i
... gbarrier
out[i] = a[i]
""",
        "...",
        seq_dependencies=True)

    kernel = lp.tag_inames(kernel, dict(i="l.0"))
    kernel = lp.duplicate_inames(
            kernel, "i", within="reads:a", tags={"i": "l.0"})

    save_and_reload_temporaries_test(cq, kernel, np.arange(10), debug)
def test_missing_temporary_definition_detection():
    """Reading ``t`` after a global barrier without save/reload must be flagged.

    Code generation is expected to raise ``MissingDefinitionError``.
    """
    from loopy.diagnostic import MissingDefinitionError

    kernel = lp.make_kernel(
            "{ [i]: 0<=i<10 }",
            """
for i
<> t = 1
... gbarrier
out[i] = t
end
""", seq_dependencies=True)

    with pytest.raises(MissingDefinitionError):
        lp.generate_code_v2(kernel)
def test_missing_definition_check_respects_aliases():
    """Reading ``b`` must not be reported as undefined when it aliases ``a``.

    Based on https://github.com/inducer/loopy/issues/69
    """
    aliased_temporaries = {
        name: lp.TemporaryVariable(
            name, dtype=np.float64, shape=("n",), base_storage="base")
        for name in ("a", "b")}

    kernel = lp.make_kernel("{ [i] : 0<=i<n }",
            ["a[i] = 0",
             "c[i] = b[i]"],
            temporary_variables=aliased_temporaries,
            target=lp.CTarget(),
            silenced_warnings=frozenset(["read_no_write(b)"]))

    lp.generate_code_v2(kernel)
def test_global_temporary(ctx_factory):
    """A global-scope temporary carried across a global barrier.

    Expects two device programs and numerical agreement with the unsplit
    reference kernel.
    """
    context = ctx_factory()

    kernel = lp.make_kernel(
            "{ [i]: 0<=i<n}",
            """
for i
<> c[i] = a[i + 1]
... gbarrier
out[i] = c[i]
end
""", seq_dependencies=True)

    kernel = lp.add_and_infer_dtypes(kernel,
            {"a": np.float32, "c": np.float32, "out": np.float32, "n": np.int32})
    kernel = lp.set_temporary_scope(kernel, "c", "global")

    reference = kernel
    kernel = lp.split_iname(kernel, "i", 128, outer_tag="g.0", inner_tag="l.0")

    code = lp.generate_code_v2(kernel)
    assert len(code.device_programs) == 2

    lp.auto_test_vs_ref(reference, context, kernel, parameters=dict(n=5))
def test_assign_to_linear_subscript(ctx_factory):
    """A linear subscript ``a[[i*n + i]]`` must match the 2D form ``a[i,i]``."""
    context = ctx_factory()
    cq = cl.CommandQueue(context)

    diag_2d = lp.make_kernel(
            "{ [i]: 0<=i<n}",
            "a[i,i] = 1")
    diag_linear = lp.make_kernel(
            "{ [i]: 0<=i<n}",
            "a[[i*n + i]] = 1",
            [lp.GlobalArg("a", shape="n,n"), "..."])

    device_arrays = []
    for kernel in (diag_2d, diag_linear):
        arr = cl.array.zeros(cq, (10, 10), np.float32)
        kernel(cq, a=arr)
        device_arrays.append(arr)

    from_2d, from_linear = (arr.get() for arr in device_arrays)
    assert np.array_equal(from_2d, from_linear)
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
def test_finite_difference_expr_subst(ctx_factory):
    """Fuse a flux kernel into a finite-difference kernel, then substitute/precompute.

    Runs the fused kernel before and after turning ``f`` into a substitution
    rule, and again after remapping, splitting and precomputing ``f_subst``.
    """
    context = ctx_factory()
    cq = cl.CommandQueue(context)

    grid = np.linspace(0, 2*np.pi, 2048, endpoint=False)
    spacing = grid[1] - grid[0]
    u = cl.clmath.sin(cl.array.to_device(cq, grid))

    diff_knl = lp.make_kernel(
            "{[i]: 1<=i<=n}",
            "out[i] = -(f[i+1] - f[i-1])/h",
            [lp.GlobalArg("out", shape="n+2"), "..."])

    flux_knl = lp.make_kernel(
            "{[j]: 1<=j<=n}",
            "f[j] = u[j]**2/2",
            [
                lp.GlobalArg("f", shape="n+2"),
                lp.GlobalArg("u", shape="n+2"),
                ])

    fused = lp.fuse_kernels([diff_knl, flux_knl],
            data_flow=[
                ("f", 1, 0)
                ])
    fused = lp.set_options(fused, write_cl=True)
    evt, _ = fused(cq, u=u, h=np.float32(1e-1))

    fused = lp.assignment_to_subst(fused, "f")
    fused = lp.set_options(fused, write_cl=True)

    # This is the real test here: The automatically generated
    # shape expressions are '2+n' and the ones above are 'n+2'.
    # Is loopy smart enough to understand that these are equal?
    evt, _ = fused(cq, u=u, h=np.float32(1e-1))

    shifted = lp.affine_map_inames(fused, "i", "inew", "inew+1=i")
    on_gpu = lp.split_iname(
            shifted, "inew", 128, outer_tag="g.0", inner_tag="l.0")

    precomputed = lp.precompute(
            on_gpu, "f_subst", "inew_inner", fetch_bounding_box=True)
    precomputed = lp.tag_inames(precomputed, {"j_0_outer": "unr"})
    precomputed = lp.set_options(precomputed, return_dict=True)
    evt, _ = precomputed(cq, u=u, h=spacing)
# {{{ call without returned values

def test_call_with_no_returned_value(ctx_factory):
    """Kernels whose only instruction is a call to ``f`` with no return value.

    ``f`` is made known to loopy through the mangler and preamble generator
    imported from ``library_for_test``. The call is built once as a
    ``CallInstruction`` object and once from textual instruction syntax.
    """
    import pymbolic.primitives as p

    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    # The closing parenthesis of this make_kernel call was missing.
    knl = lp.make_kernel(
        "{:}",
        [lp.CallInstruction((), p.Call(p.Variable("f"), ()))]
        )

    from library_for_test import no_ret_f_mangler, no_ret_f_preamble_gen
    knl = lp.register_function_manglers(knl, [no_ret_f_mangler])
    knl = lp.register_preamble_generators(knl, [no_ret_f_preamble_gen])

    evt, _ = knl(queue)

    knl = lp.make_kernel(
        "{:}",
        "f() {id=init}"
        )
    knl = lp.register_function_manglers(knl, [no_ret_f_mangler])

    print(lp.generate_code_v2(knl).device_code())
Dominic Kempf
committed
def test_unschedulable_kernel_detection():
    """Detect unschedulable iname nestings and check the duplication remedies.

    For the first kernel, every iname-duplication option offered by loopy
    must make it schedulable. For the second kernel, only the number of
    offered options is checked.
    """
    knl = lp.make_kernel(["{[i,j]:0<=i,j<n}"],
                         """
mat1[i,j] = mat1[i,j] + 1 {inames=i:j, id=i1}
mat2[j] = mat2[j] + 1 {inames=j, id=i2}
mat3[i] = mat3[i] + 1 {inames=i, id=i3}
""")

    knl = lp.preprocess_kernel(knl)

    # Check that loopy can detect the unschedulability of the kernel
    assert not lp.has_schedulable_iname_nesting(knl)
    assert len(list(lp.get_iname_duplication_options(knl))) == 4

    for inames, insns in lp.get_iname_duplication_options(knl):
        fixed_knl = lp.duplicate_inames(knl, inames, insns)
        assert lp.has_schedulable_iname_nesting(fixed_knl)

    knl = lp.make_kernel(["{[i,j,k,l,m]:0<=i,j,k,l,m<n}"],
                         """
mat1[l,m,i,j,k] = mat1[l,m,i,j,k] + 1 {inames=i:j:k:l:m}
mat2[l,m,j,k] = mat2[l,m,j,k] + 1 {inames=j:k:l:m}
mat3[l,m,k] = mat3[l,m,k] + 11 {inames=k:l:m}
mat4[l,m,i] = mat4[l,m,i] + 1 {inames=i:l:m}
""")

    assert not lp.has_schedulable_iname_nesting(knl)
    assert len(list(lp.get_iname_duplication_options(knl))) == 10
Andreas Klöckner
committed
def test_regression_no_ret_call_removal(ctx_factory):
    """Regression test for a call without a used return value.

    See https://github.com/inducer/loopy/issues/32
    """
    kernel = lp.add_and_infer_dtypes(
            lp.make_kernel("{[i] : 0<=i<n}", "f(sum(i, x[i]))"),
            {"x": np.float32})
    kernel = lp.preprocess_kernel(kernel)

    # Three instructions are expected after preprocessing (presumably the
    # reduction's init/update plus the call itself).
    assert len(kernel.instructions) == 3
def test_regression_persistent_hash():
    """Kernels differing only in a constant subscript must get different keys."""
    kernel_d2 = lp.make_kernel(
            "{[i] : 0<=i<n}",
            "cse_exprvar = d[2]*d[2]")
    kernel_d0 = lp.make_kernel(
            "{[i] : 0<=i<n}",
            "cse_exprvar = d[0]*d[0]")

    from loopy.tools import LoopyKeyBuilder
    key_of = LoopyKeyBuilder()

    # Both the single instruction and the whole kernel must hash apart.
    assert key_of(kernel_d2.instructions[0]) != key_of(kernel_d0.instructions[0])
    assert key_of(kernel_d2) != key_of(kernel_d0)
def test_sequential_dependencies(ctx_factory):
    """A ``seq_dependencies=True`` kernel with complex literals must run.

    The kernel is compared against itself via ``auto_test_vs_ref``.
    """
    context = ctx_factory()

    kernel = lp.make_kernel(
            "{[i]: 0<=i<n}",
            """
for i
<> aa = 5jf
<> bb = 5j
a[i] = imag(aa)
b[i] = imag(bb)
c[i] = 5f
end
""", seq_dependencies=True)

    print(kernel.stringify(with_dependencies=True))

    lp.auto_test_vs_ref(kernel, context, kernel, parameters=dict(n=5))
# Kernel using a `... nop` instruction as a dependency join point.
# NOTE(review): incomplete in this copy and does not parse as-is:
#  - `ctx` is used at the end but never created (no `ctx = ctx_factory()`);
#  - the nop depends on ids `wr_z` and `wr_v`, but no instructions with those
#    ids are visible;
#  - the instruction string has no closing `end`, triple quote or parenthesis,
#    so the following statements would fall inside the string.
# TODO restore from upstream.
def test_nop(ctx_factory):
knl = lp.make_kernel(
"{[i,itrip]: 0<=i<n and 0<=itrip<ntrips}",
"""
for itrip,i
... nop {dep=wr_z:wr_v,id=yoink}
z[i] = z[i] - z[i+1] + v[i] {dep=yoink}
knl = lp.fix_parameters(knl, n=15)
knl = lp.add_and_infer_dtypes(knl, {"z": np.float64})
lp.auto_test_vs_ref(knl, ctx, knl, parameters=dict(ntrips=5))
def test_global_barrier(ctx_factory):
    """Global barriers inside and after a loop, with implicit global temporaries.

    After preprocessing, ``z`` and ``v`` must have global scope. The split
    kernel is compared against a reference in which that scope is set
    explicitly.
    """
    ctx = ctx_factory()

    # Stray numeric lines (scrape residue) previously sat inside this
    # instruction string and would have been parsed as instructions.
    knl = lp.make_kernel(
            "{[i,itrip]: 0<=i<n and 0<=itrip<ntrips}",
            """
for i
for itrip
... gbarrier {id=top}
<> z[i] = z[i+1] + z[i] {id=wr_z,dep=top}
<> v[i] = 11 {id=wr_v,dep=top}
... gbarrier {dep=wr_z:wr_v,id=yoink}
z[i] = z[i] - z[i+1] + v[i] {id=iupd, dep=wr_z}
end
... gbarrier {dep=iupd,id=postloop}
z[i] = z[i] - z[i+1] + v[i] {dep=postloop}
end
""")

    knl = lp.fix_parameters(knl, ntrips=3)
    knl = lp.add_and_infer_dtypes(knl, {"z": np.float64})

    ref_knl = knl
    ref_knl = lp.set_temporary_scope(ref_knl, "z", "global")
    ref_knl = lp.set_temporary_scope(ref_knl, "v", "global")

    knl = lp.split_iname(knl, "i", 256, outer_tag="g.0", inner_tag="l.0")
    print(knl)

    knl = lp.preprocess_kernel(knl)
    assert knl.temporary_variables["z"].scope == lp.temp_var_scope.GLOBAL
    assert knl.temporary_variables["v"].scope == lp.temp_var_scope.GLOBAL

    print(knl)

    lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(ntrips=5, n=10))
def test_missing_global_barrier():
    """Scheduling must raise ``MissingBarrierError`` when a needed gbarrier is absent."""
    from loopy.diagnostic import MissingBarrierError

    kernel = lp.make_kernel(
            "{[i,itrip]: 0<=i<n and 0<=itrip<ntrips}",
            """
for i
for itrip
... gbarrier {id=yoink}
<> z[i] = z[i] - z[i+1] {id=iupd,dep=yoink}
end
# This is where the barrier should be
z[i] = z[i] - z[i+1] + v[i] {dep=iupd}
end
""")

    kernel = lp.set_temporary_scope(kernel, "z", "global")
    kernel = lp.split_iname(kernel, "i", 256, outer_tag="g.0")
    kernel = lp.preprocess_kernel(kernel)

    with pytest.raises(MissingBarrierError):
        lp.get_one_scheduled_kernel(kernel)
def test_index_cse(ctx_factory):
    """Generate code for a two-index reduction with ``l`` unrolled.

    Only code generation is exercised; the device code is printed.
    """
    knl = lp.make_kernel(["{[i,j,k,l,m]:0<=i,j,k,l,m<n}"],
                         """
for i
for j
c[i,j,m] = sum((k,l), a[i,j,l]*b[i,j,k,l])
end
end
""")
    knl = lp.tag_inames(knl, "l:unr")
    knl = lp.prioritize_loops(knl, "i,j,k,l")
    knl = lp.add_and_infer_dtypes(knl, {"a": np.float32, "b": np.float32})
    knl = lp.fix_parameters(knl, n=5)
    print(lp.generate_code_v2(knl).device_code())
def test_ilp_and_conditionals(ctx_factory):
    """An ILP-split loop containing a conditional must match the unsplit kernel."""
    ctx = ctx_factory()

    # Well-formed ISL set; the previous string had a stray trailing '}'.
    knl = lp.make_kernel('{[k]: 0<=k<n}',
        """
for k
<> Tcond = T[k] < 0.5
if Tcond
cp[k] = 2 * T[k] + Tcond
end
end
""")

    knl = lp.fix_parameters(knl, n=200)
    knl = lp.add_and_infer_dtypes(knl, {"T": np.float32})

    ref_knl = knl

    knl = lp.split_iname(knl, 'k', 2, inner_tag='ilp')

    lp.auto_test_vs_ref(ref_knl, ctx, knl)
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
def test_unr_and_conditionals(ctx_factory):
    """An unrolled loop with an array-valued condition must match the unsplit kernel."""
    ctx = ctx_factory()

    # Well-formed ISL set; the previous string had a stray trailing '}'.
    knl = lp.make_kernel('{[k]: 0<=k<n}',
        """
for k
<> Tcond[k] = T[k] < 0.5
if Tcond[k]
cp[k] = 2 * T[k] + Tcond[k]
end
end
""")

    knl = lp.fix_parameters(knl, n=200)
    knl = lp.add_and_infer_dtypes(knl, {"T": np.float32})

    ref_knl = knl

    knl = lp.split_iname(knl, 'k', 2, inner_tag='unr')

    lp.auto_test_vs_ref(ref_knl, ctx, knl)
def test_constant_array_args(ctx_factory):
ctx = ctx_factory()
knl = lp.make_kernel('{[k]: 0<=k<n}}',
"""
for k
<> Tcond[k] = T[k] < 0.5
if Tcond[k]
cp[k] = 2 * T[k] + Tcond[k]
end
end
""",
[lp.ConstantArg('T', shape=(200,), dtype=np.float32),
'...'])