def image_matrix_mul(ctx_factory=cl.create_some_context):
    """Time a loopy-generated matrix product whose inputs are images.

    The kernel computes ``c[i, j] = a[i, k]*b[k, j]`` over
    ``0 <= i, j, k < n``.  ``a`` and ``b`` are declared to the kernel as
    ``lp.ImageArg`` and passed at launch time as images created with
    ``cl.image_from_array``; ``c`` is an ordinary ``(n, n)`` array argument.

    Every schedule produced by ``lp.generate_loop_schedules`` is handed to
    ``lp.drive_timing_run``.  When the driver calls the launcher with
    ``check`` set, the device result is compared against a host-side
    ``np.dot`` reference via ``check_error``.

    :arg ctx_factory: zero-argument callable returning the OpenCL context
        to run on.
    """
    dtype = np.float32
    order = "C"
    n = 16*100

    ctx = ctx_factory()
    # Profiling is enabled on the queue -- presumably so the timing driver
    # can read event timestamps (confirm against lp.drive_timing_run).
    queue = cl.CommandQueue(ctx,
            properties=cl.command_queue_properties.PROFILING_ENABLE)

    from pymbolic import var
    # Symbolic names get a ``_sym`` suffix so they cannot be confused with
    # the device arrays created further down.
    a_sym, b_sym, c_sym, i, j, k = [var(s) for s in "abcijk"]

    knl = lp.LoopKernel(ctx.devices[0],
            "{[i,j,k]: 0<=i,j,k<%d}" % n,
            [
                (c_sym[i, j], a_sym[i, k]*b_sym[k, j])
                ],
            [
                lp.ImageArg("a", dtype, 2),
                lp.ImageArg("b", dtype, 2),
                lp.ArrayArg("c", dtype, shape=(n, n), order=order),
                ],
            name="matmul")

    knl = lp.split_dimension(knl, "i", 16, outer_tag="g.0", inner_tag="l.1")
    knl = lp.split_dimension(knl, "j", 16, outer_tag="g.1", inner_tag="l.0")
    knl = lp.split_dimension(knl, "k", 32)
    # conflict-free
    knl = lp.add_prefetch(knl, 'a', ["i_inner", "k_inner"])
    knl = lp.add_prefetch(knl, 'b', ["j_inner", "k_inner"])

    # Surface the reason in the failure message instead of a bare assert.
    invalid_reason = knl.get_invalid_reason()
    assert invalid_reason is None, invalid_reason

    # Lazily post-process each candidate schedule; a distinct loop variable
    # avoids shadowing ``knl``.
    kernel_gen = (lp.insert_register_prefetches(scheduled_knl)
            for scheduled_knl in lp.generate_loop_schedules(knl))

    a_dev = make_well_conditioned_dev_matrix(
            queue, n, dtype=dtype, order=order)
    b_dev = make_well_conditioned_dev_matrix(
            queue, n, dtype=dtype, order=order)
    c_dev = cl_array.empty_like(a_dev)

    # Fetch each input to the host once and reuse it for both the reference
    # solution and the image upload.
    a_host = a_dev.get()
    b_host = b_dev.get()
    refsol = np.dot(a_host, b_host)
    a_img = cl.image_from_array(ctx, a_host, 1)
    b_img = cl.image_from_array(ctx, b_host, 1)

    def launcher(kernel, gsize, lsize, check):
        """Enqueue *kernel* once and return its event.

        If *check* is true, the output is also compared with ``refsol``.
        """
        evt = kernel(queue, gsize(), lsize(), a_img, b_img, c_dev.data,
                g_times_l=True)

        if check:
            check_error(refsol, c_dev.get())

        return evt

    # 2*n**3 is presumably the operation count of an n x n matrix product
    # (one multiply and one add per inner iteration) -- confirm against
    # lp.drive_timing_run.
    lp.drive_timing_run(kernel_gen, queue, launcher, 2*n**3,
            options=FAST_OPTIONS + ["-cl-nv-verbose"],
            force_rebuild=True)