From 92d7932b5c8ad5508179c67ef91fa19482808a35 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Sun, 14 Jan 2018 22:38:48 -0600
Subject: [PATCH] Fixing logical errors in `test_to_bathced` and
 `test_to_batched_temps` so that they actually check that the output is
 correct

---
 test/test_transform.py | 25 ++++++++++++++++++++++---
 1 file changed, 22 insertions(+), 3 deletions(-)

diff --git a/test/test_transform.py b/test/test_transform.py
index ccbbd3da0..0e10db362 100644
--- a/test/test_transform.py
+++ b/test/test_transform.py
@@ -96,13 +96,28 @@ def test_to_batched(ctx_factory):
     knl = lp.make_kernel(
          ''' { [i,j]: 0<=i,j<n } ''',
          ''' out[i] = sum(j, a[i,j]*x[j])''')
+    knl = lp.add_and_infer_dtypes(knl, dict(out=np.float32,
+                                            x=np.float32,
+                                            a=np.float32))
 
     bknl = lp.to_batched(knl, "nbatches", "out,x")
 
-    a = np.random.randn(5, 5)
-    x = np.random.randn(7, 5)
+    ref_knl = lp.make_kernel(
+         ''' { [i,j,k]: 0<=i,j<n and 0<=k<nbatches} ''',
+         '''out[k, i] = sum(j, a[i,j]*x[k, j])''')
+    ref_knl = lp.add_and_infer_dtypes(ref_knl, dict(out=np.float32,
+                                                    x=np.float32,
+                                                    a=np.float32))
 
-    bknl(queue, a=a, x=x)
+    a = np.random.randn(5, 5).astype(np.float32)
+    x = np.random.randn(7, 5).astype(np.float32)
+
+    # Running both the kernels
+    evt, (out1, ) = bknl(queue, a=a, x=x, n=5, nbatches=7)
+    evt, (out2, ) = ref_knl(queue, a=a, x=x, n=5, nbatches=7)
+
+    # checking that the outputs are same
+    assert np.linalg.norm(out1-out2) < 1e-15
 
 
 def test_to_batched_temp(ctx_factory):
@@ -130,9 +145,13 @@ def test_to_batched_temp(ctx_factory):
     bknl = lp.to_batched(knl, "nbatches", "out,x")
     bref_knl = lp.to_batched(ref_knl, "nbatches", "out,x")
 
+    # checking that cnst is not being bathced
+    assert bknl.temporary_variables['cnst'].shape == ()
+
     a = np.random.randn(5, 5)
     x = np.random.randn(7, 5)
 
+    # Checking that the program compiles and the logic is correct
     lp.auto_test_vs_ref(
             bref_knl, ctx, bknl,
             parameters=dict(a=a, x=x, n=5, nbatches=7))
-- 
GitLab