diff --git a/test/test_matmul.py b/test/test_matmul.py
index 55504a877c752eda3474ca09f87bfb1fe5f06443..b0b3b123a33920d549c477cabfc96ae5c4f2ef4b 100644
--- a/test/test_matmul.py
+++ b/test/test_matmul.py
@@ -94,7 +94,7 @@ def test_axpy(ctx_factory):
     queue = cl.CommandQueue(ctx,
             properties=cl.command_queue_properties.PROFILING_ENABLE)
 
-    n = get_suitable_size(ctx)**3
+    n = 20*1024**2
 
     knl = lp.LoopKernel(ctx.devices[0],
             "[n] -> {[i]: 0<=i<n}",
@@ -109,31 +109,41 @@ def test_axpy(ctx_factory):
                 lp.ArrayArg("z", dtype, shape="n,"),
                 lp.ScalarArg("n", np.int32, approximately=n),
                 ],
-            name="matmul", assumptions="n>=4096")
+            name="matmul")
 
-    unroll = 4
-    block_size = 256
-    knl = lp.split_dimension(knl, "i", unroll*block_size, outer_tag="g.0", slabs=(0, 1))
-    knl = lp.split_dimension(knl, "i_inner", block_size, outer_tag="unr", inner_tag="l.0")
+    def variant_cpu(knl):
+        unroll = 16
+        block_size = unroll*4096
+        knl = lp.split_dimension(knl, "i", block_size, outer_tag="g.0", slabs=(0, 1))
+        knl = lp.split_dimension(knl, "i_inner", unroll, inner_tag="unr")
+        return knl
 
-    kernel_gen = lp.generate_loop_schedules(knl)
-    kernel_gen = lp.check_kernels(kernel_gen, dict(n=n), kill_level_min=5)
+    def variant_gpu(knl):
+        unroll = 4
+        block_size = 256
+        knl = lp.split_dimension(knl, "i", unroll*block_size, outer_tag="g.0", slabs=(0, 1))
+        knl = lp.split_dimension(knl, "i_inner", block_size, outer_tag="unr", inner_tag="l.0")
+        return knl
 
-    a = cl_random.rand(queue, n, dtype=dtype)
-    b = cl_random.rand(queue, n, dtype=dtype)
+    a = cl_random.rand(queue, n, dtype=dtype, luxury=2)
+    b = cl_random.rand(queue, n, dtype=dtype, luxury=2)
     c = cl_array.zeros_like(a)
     refsol = (2*a+3*b).get()
 
-    def launcher(kernel, gsize, lsize, check):
-        evt = kernel(queue, gsize(n), lsize(n), 2, a.data, 3, b.data, c.data, n,
-                g_times_l=True)
+    for variant in [variant_cpu, variant_gpu]:
+        kernel_gen = lp.generate_loop_schedules(variant(knl))
+        kernel_gen = lp.check_kernels(kernel_gen, dict(n=n), kill_level_min=5)
 
-        if check:
-            check_error(refsol, c.get())
+        def launcher(kernel, gsize, lsize, check):
+            evt = kernel(queue, gsize(n), lsize(n), 2, a.data, 3, b.data, c.data, n,
+                    g_times_l=True)
 
-        return evt
+            if check:
+                check_error(refsol, c.get())
+
+            return evt
 
-    lp.drive_timing_run(kernel_gen, queue, launcher, 5*n)
+        lp.drive_timing_run(kernel_gen, queue, launcher, 5*n)