Skip to content
Snippets Groups Projects
Commit 480223d2 authored by Andreas Klöckner
Browse files

Add transpose test.

parent 4853b6d5
No related branches found
No related tags found
No related merge requests found
......@@ -150,6 +150,54 @@ def test_axpy(ctx_factory):
def test_transpose(ctx_factory):
    """Generate, check and time kernels for ``b[i, j] = a[j, i]``.

    The kernel is built for an ``n`` x ``n`` float32 problem, where ``n``
    comes from ``get_suitable_size(ctx)``.  Both loop dimensions are split
    by 16 and the input ``a`` is prefetched over the two inner inames.
    Every schedule that survives ``lp.check_kernels`` is handed to
    ``lp.drive_timing_run``; when the driver asks for a check, the device
    result is compared against the host-side transpose of ``a``.

    :arg ctx_factory: zero-argument callable returning the OpenCL context
        to run on.
    """
    dtype = np.float32
    ctx = ctx_factory()
    order = "C"
    # Profiling is enabled on the queue that the timing driver receives.
    queue = cl.CommandQueue(ctx,
            properties=cl.command_queue_properties.PROFILING_ENABLE)

    n = get_suitable_size(ctx)

    knl = lp.make_kernel(ctx.devices[0],
            "{[i,j]: 0<=i,j<%d}" % n,
            [
                "b[i, j] = a[j, i]"
                ],
            [
                lp.ArrayArg("a", dtype, shape=(n, n), order=order),
                lp.ArrayArg("b", dtype, shape=(n, n), order=order),
                ],
            name="transpose")

    knl = lp.split_dimension(knl, "i", 16,
            outer_tag="g.0", inner_tag="l.1")
    knl = lp.split_dimension(knl, "j", 16,
            outer_tag="g.1", inner_tag="l.0")

    # Only the input array is prefetched.  A former second prefetch on 'b'
    # named the iname "k_inner", which this kernel never defines (the
    # domain has only i and j), and 'b' is the array being written.
    knl = lp.add_prefetch(knl, 'a', ["i_inner", "j_inner"])

    kernel_gen = lp.generate_loop_schedules(knl)
    kernel_gen = lp.check_kernels(kernel_gen, {}, kill_level_min=5)

    a = make_well_conditioned_dev_matrix(queue, n, dtype=dtype, order=order)
    b = cl_array.empty_like(a)
    # Host-side reference: the transpose of the input matrix.
    refsol = a.get().T.copy()

    def launcher(kernel, gsize, lsize, check):
        # Run one generated kernel; verify the output only when asked to.
        evt = kernel(queue, gsize(), lsize(), a.data, b.data,
                g_times_l=True)

        if check:
            check_error(refsol, b.get())

        return evt

    lp.drive_timing_run(kernel_gen, queue, launcher, 0)
def test_plain_matrix_mul(ctx_factory):
dtype = np.float32
ctx = ctx_factory()
......@@ -838,47 +886,7 @@ def main_elwise_scaled_matrix_mul():
def main_transpose():
    # Driver for a transpose loop kernel over an n-by-n index space: it
    # either times the generated kernels (when HAVE_CUDA is set) or just
    # shows their code.
    #
    # NOTE(review): the indentation of this function was lost in
    # extraction; the nesting below (inner `if HAVE_CUDA:` guarding only
    # the array setup, `else:` paired with the outer `if`) is a
    # reconstruction -- confirm against the original file.
    n = 16*48
    from pymbolic import var
    # Symbolic variables used to spell out the subscript expressions.
    a, b, i, j = [var(s) for s in "abij"]

    # One loop dimension of extent n for each of i and j; the single
    # entry pairs b[i+n*j] with a[j+n*i] (presumably target and source
    # of the assignment -- TODO confirm against make_loop_kernel).
    k = make_loop_kernel([
        LoopDimension("i", n),
        LoopDimension("j", n),
        ], [
            (b[i+n*j], a[j+n*i])
            ])

    # Forwarded unchanged to generate_all_kernels below.
    gen_kwargs = {
            "min_threads": 128,
            "min_blocks": 32,
            }

    if True and HAVE_CUDA:
        if HAVE_CUDA:
            # From here on a and b name device arrays, shadowing the
            # symbolic variables of the same names used to build k.
            a = curandom.rand((n, n))
            b = gpuarray.empty_like(a)

        def launcher(grid, kernel, texref_lookup):
            # Bind the input array to the kernel's "a" texture reference,
            # then launch with the output buffer as the only argument.
            a.bind_to_texref_ext(texref_lookup["a"])
            kernel.prepared_call(grid, b.gpudata)

        drive_timing_run(
                generate_all_kernels(k, **gen_kwargs),
                launcher, 0)
    else:
        # Without CUDA, only display the generated kernel code.
        show_kernel_codes(generate_all_kernels(k, **gen_kwargs))
if __name__ == "__main__":
# make sure that import failures get reported, instead of skipping the
# tests.
import pyopencl as cl
import sys
if len(sys.argv) > 1:
exec(sys.argv[1])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment