Skip to content
Snippets Groups Projects
Commit 535a8755 authored by Kaushik Kulkarni's avatar Kaushik Kulkarni
Browse files

Added a test for slice

parent da2d437d
Branches
Tags
2 merge requests!426Discussion: kernel_callables_v3-edit2,!246WIP: Kernel Callables
......@@ -230,6 +230,49 @@ def test_register_knl(ctx_factory):
np.linalg.norm(2*x+3*y))) < 1e-15
def test_slices(ctx_factory):
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
n = 2 ** 4
x = np.random.rand(n, n, n, n, n)
y = np.random.rand(n, n, n, n, n)
child_knl = lp.make_kernel(
"{[i, j]:0<=i, j < 16}",
"""
g[i, j] = 2*e[i, j] + 3*f[i, j]
""")
parent_knl = lp.make_kernel(
"{[i, k, m]: 0<=i, k, m<16}",
"""
z[i, :, k, :, m] = linear_combo(x[i, :, k, :, m], y[i, :, k, :, m])
""",
kernel_data=[
lp.GlobalArg(
name='x',
dtype=np.float64,
shape=(16, 16, 16, 16, 16)),
lp.GlobalArg(
name='y',
dtype=np.float64,
shape=(16, 16, 16, 16, 16)),
lp.GlobalArg(
name='z',
dtype=np.float64,
shape=(16, 16, 16, 16, 16)), '...'],
)
knl = lp.register_callable_kernel(
parent_knl, 'linear_combo', child_knl)
evt, (out, ) = knl(queue, x=x, y=y)
assert (np.linalg.norm(2*x+3*y-out)/(
np.linalg.norm(2*x+3*y))) < 1e-15
def test_rename_argument(ctx_factory):
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment