diff --git a/test/test_matmul.py b/test/test_matmul.py
index 55504a877c752eda3474ca09f87bfb1fe5f06443..b0b3b123a33920d549c477cabfc96ae5c4f2ef4b 100644
--- a/test/test_matmul.py
+++ b/test/test_matmul.py
@@ -94,7 +94,7 @@ def test_axpy(ctx_factory):
     queue = cl.CommandQueue(ctx,
             properties=cl.command_queue_properties.PROFILING_ENABLE)
 
-    n = get_suitable_size(ctx)**3
+    n = 20*1024**2
 
     knl = lp.LoopKernel(ctx.devices[0],
             "[n] -> {[i]: 0<=i<n}",
@@ -109,31 +109,41 @@ def test_axpy(ctx_factory):
                 lp.ArrayArg("z", dtype, shape="n,"),
                 lp.ScalarArg("n", np.int32, approximately=n),
                 ],
-            name="matmul", assumptions="n>=4096")
+            name="matmul")
 
-    unroll = 4
-    block_size = 256
-    knl = lp.split_dimension(knl, "i", unroll*block_size, outer_tag="g.0", slabs=(0, 1))
-    knl = lp.split_dimension(knl, "i_inner", block_size, outer_tag="unr", inner_tag="l.0")
+    def variant_cpu(knl):
+        unroll = 16
+        block_size = unroll*4096
+        knl = lp.split_dimension(knl, "i", block_size, outer_tag="g.0", slabs=(0, 1))
+        knl = lp.split_dimension(knl, "i_inner", unroll, inner_tag="unr")
+        return knl
 
-    kernel_gen = lp.generate_loop_schedules(knl)
-    kernel_gen = lp.check_kernels(kernel_gen, dict(n=n), kill_level_min=5)
+    def variant_gpu(knl):
+        unroll = 4
+        block_size = 256
+        knl = lp.split_dimension(knl, "i", unroll*block_size, outer_tag="g.0", slabs=(0, 1))
+        knl = lp.split_dimension(knl, "i_inner", block_size, outer_tag="unr", inner_tag="l.0")
+        return knl
 
-    a = cl_random.rand(queue, n, dtype=dtype)
-    b = cl_random.rand(queue, n, dtype=dtype)
+    a = cl_random.rand(queue, n, dtype=dtype, luxury=2)
+    b = cl_random.rand(queue, n, dtype=dtype, luxury=2)
     c = cl_array.zeros_like(a)
     refsol = (2*a+3*b).get()
 
-    def launcher(kernel, gsize, lsize, check):
-        evt = kernel(queue, gsize(n), lsize(n), 2, a.data, 3, b.data, c.data, n,
-                g_times_l=True)
+    for variant in [variant_cpu, variant_gpu]:
+        kernel_gen = lp.generate_loop_schedules(variant(knl))
+        kernel_gen = lp.check_kernels(kernel_gen, dict(n=n), kill_level_min=5)
 
-        if check:
-            check_error(refsol, c.get())
+        def launcher(kernel, gsize, lsize, check):
+            evt = kernel(queue, gsize(n), lsize(n), 2, a.data, 3, b.data, c.data, n,
+                    g_times_l=True)
 
-        return evt
+            if check:
+                check_error(refsol, c.get())
+
+            return evt
 
-    lp.drive_timing_run(kernel_gen, queue, launcher, 5*n)
+        lp.drive_timing_run(kernel_gen, queue, launcher, 5*n)