Skip to content
Snippets Groups Projects
Commit feb226f2 authored by Kaushik Kulkarni's avatar Kaushik Kulkarni
Browse files

Compared the correctness of the test using `auto_test_vs_ref`

parent c98820b8
No related branches found
No related tags found
1 merge request!204To batched private temps
Pipeline #
......@@ -107,7 +107,6 @@ def test_to_batched(ctx_factory):
def test_to_batched_temp(ctx_factory):
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
knl = lp.make_kernel(
''' { [i,j]: 0<=i,j<n } ''',
......@@ -115,16 +114,28 @@ def test_to_batched_temp(ctx_factory):
out[i] = sum(j, cnst*a[i,j]*x[j])''',
[lp.TemporaryVariable(
"cnst",
dtype=np.float64,
dtype=np.float32,
shape=(),
scope=lp.temp_var_scope.PRIVATE), '...'])
knl = lp.add_and_infer_dtypes(knl, dict(out=np.float32,
x=np.float32,
a=np.float32))
ref_knl = lp.make_kernel(
''' { [i,j]: 0<=i,j<n } ''',
'''out[i] = sum(j, 2.0*a[i,j]*x[j])''')
ref_knl = lp.add_and_infer_dtypes(ref_knl, dict(out=np.float32,
x=np.float32,
a=np.float32))
bknl = lp.to_batched(knl, "nbatches", "out,x")
bref_knl = lp.to_batched(ref_knl, "nbatches", "out,x")
a = np.random.randn(5, 5)
x = np.random.randn(7, 5)
bknl(queue, a=a, x=x)
lp.auto_test_vs_ref(
bref_knl, ctx, bknl,
parameters=dict(a=a, x=x, n=5, nbatches=7))
def test_add_barrier(ctx_factory):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment