Skip to content
Snippets Groups Projects
Commit 92d7932b authored by Kaushik Kulkarni's avatar Kaushik Kulkarni
Browse files

Fixing logical errors in `test_to_bathced` and `test_to_batched_temps` so that...

Fixing logical errors in `test_to_bathced` and `test_to_batched_temps` so that they actually check that the output is correct
parent feb226f2
No related branches found
No related tags found
No related merge requests found
......@@ -96,13 +96,28 @@ def test_to_batched(ctx_factory):
knl = lp.make_kernel(
''' { [i,j]: 0<=i,j<n } ''',
''' out[i] = sum(j, a[i,j]*x[j])''')
knl = lp.add_and_infer_dtypes(knl, dict(out=np.float32,
x=np.float32,
a=np.float32))
bknl = lp.to_batched(knl, "nbatches", "out,x")
a = np.random.randn(5, 5)
x = np.random.randn(7, 5)
ref_knl = lp.make_kernel(
''' { [i,j,k]: 0<=i,j<n and 0<=k<nbatches} ''',
'''out[k, i] = sum(j, a[i,j]*x[k, j])''')
ref_knl = lp.add_and_infer_dtypes(ref_knl, dict(out=np.float32,
x=np.float32,
a=np.float32))
bknl(queue, a=a, x=x)
a = np.random.randn(5, 5).astype(np.float32)
x = np.random.randn(7, 5).astype(np.float32)
# Running both the kernels
evt, (out1, ) = bknl(queue, a=a, x=x, n=5, nbatches=7)
evt, (out2, ) = ref_knl(queue, a=a, x=x, n=5, nbatches=7)
# checking that the outputs are same
assert np.linalg.norm(out1-out2) < 1e-15
def test_to_batched_temp(ctx_factory):
......@@ -130,9 +145,13 @@ def test_to_batched_temp(ctx_factory):
bknl = lp.to_batched(knl, "nbatches", "out,x")
bref_knl = lp.to_batched(ref_knl, "nbatches", "out,x")
# checking that cnst is not being bathced
assert bknl.temporary_variables['cnst'].shape == ()
a = np.random.randn(5, 5)
x = np.random.randn(7, 5)
# Checking that the program compiles and the logic is correct
lp.auto_test_vs_ref(
bref_knl, ctx, bknl,
parameters=dict(a=a, x=x, n=5, nbatches=7))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment