diff --git a/test.py b/test.py
index 4ab9363d28f11a7bdf2f49ef8e5ae63bce73515c..2446f40d8f33b00d92bbcfec0ad655d51a601f71 100644
--- a/test.py
+++ b/test.py
@@ -47,22 +47,7 @@ def test_compute_flux_derivatives(ctx_factory):
             metric_jacobians=metric_jacobians)
 
 
-def test_compute_flux_derivatives_gpu(ctx_factory):
-    logging.basicConfig(level="INFO")
-
-    queue = f.get_queue(ctx_factory)
-
-    ndim = 3
-    nvars = 5
-    nx = 10
-    ny = 10
-    nz = 10
-
-    states = f.random_array(nvars, nx+6, ny+6, nz+6)
-    fluxes = f.random_array(nvars, ndim, nx+6, ny+6, nz+6)
-    metrics = f.random_array(ndim, ndim, nx+6, ny+6, nz+6)
-    metric_jacobians = f.random_array(nx+6, ny+6, nz+6)
-
+def get_gpu_transformed_weno():
     prg = f.prg
 
     cfd = prg["compute_flux_derivatives"]
@@ -95,6 +80,8 @@ def test_compute_flux_derivatives_gpu(ctx_factory):
 
     prg = prg.with_kernel(cfd)
 
+    # FIXME: The inline_callable_kernel calls below should work, but don't.
+    # FIXME: Undo the hand-inlining in WENO.F90.
     #prg = lp.inline_callable_kernel(prg, "convert_to_generalized")
     #prg = lp.inline_callable_kernel(prg, "convert_from_generalized")
 
@@ -102,10 +89,31 @@ def test_compute_flux_derivatives_gpu(ctx_factory):
         print(prg["convert_to_generalized_frozen"])
         1/0
 
+    return prg
+
+
+def test_compute_flux_derivatives_gpu(ctx_factory):
+    logging.basicConfig(level="INFO")
+
+    prg = get_gpu_transformed_weno()
+
+    queue = f.get_queue(ctx_factory)
+
+    ndim = 3
+    nvars = 5
+    nx = 10
+    ny = 10
+    nz = 10
+
+    states = f.random_array(nvars, nx+6, ny+6, nz+6)
+    fluxes = f.random_array(nvars, ndim, nx+6, ny+6, nz+6)
+    metrics = f.random_array(ndim, ndim, nx+6, ny+6, nz+6)
+    metric_jacobians = f.random_array(nx+6, ny+6, nz+6)
+
     flux_derivatives_dev = cl.array.empty(queue, (nvars, ndim, nx+6, ny+6,
         nz+6), dtype=np.float32, order="F")
 
-    if 0:
+    if 1:
         with open("gen-code.cl", "w") as outf:
             outf.write(lp.generate_code_v2(prg).device_code())
 
@@ -113,11 +121,66 @@ def test_compute_flux_derivatives_gpu(ctx_factory):
             states=states, fluxes=fluxes, metrics=metrics,
             metric_jacobians=metric_jacobians,
             flux_derivatives=flux_derivatives_dev)
-    return flux_derivatives_dev.get()
 
 
-# This lets you run 'python test.py test_case(cl._csc)' without pytest.
+def benchmark_compute_flux_derivatives_gpu(ctx_factory):
+    """Time the program from get_gpu_transformed_weno() and print DOFs/s.
+
+    Two untimed warm-up rounds run first; ``nrounds`` rounds are then timed
+    between ``queue.finish()`` calls. *ctx_factory* goes to ``f.get_queue``.
+    """
+    logging.basicConfig(level="INFO")
+    prg = get_gpu_transformed_weno()
+    queue = f.get_queue(ctx_factory)
+
+    ndim = 3
+    nvars = 5
+    n = 100
+    nx = ny = nz = n
+    nrounds = 10
+
+    states = f.random_array(nvars, nx+6, ny+6, nz+6)
+    fluxes = f.random_array(nvars, ndim, nx+6, ny+6, nz+6)
+    metrics = f.random_array(ndim, ndim, nx+6, ny+6, nz+6)
+    metric_jacobians = f.random_array(nx+6, ny+6, nz+6)
+
+    flux_derivatives_dev = cl.array.empty(queue, (nvars, ndim, nx+6, ny+6,
+        nz+6), dtype=np.float32, order="F")
 
+    if 0:
+        with open("gen-code.cl", "w") as outf:
+            outf.write(lp.generate_code_v2(prg).device_code())
+
+    prg = prg.copy(target=lp.PyOpenCLTarget(queue.device))
+    prg = lp.set_options(prg, ignore_boostable_into=True, write_wrapper=True)
+    #op_map = lp.get_op_map(prg, count_redundant_work=False)
+    #print(op_map)
+
+    from functools import partial
+    run = partial(prg, queue, nvars=nvars, ndim=ndim,
+            states=states, fluxes=fluxes, metrics=metrics,
+            metric_jacobians=metric_jacobians,
+            flux_derivatives=flux_derivatives_dev)
+
+    print("warmup")
+    for _ in range(2):
+        run()
+
+    # Drain the warm-up work so it is not charged to the timed rounds.
+    queue.finish()
+    print("timing")
+    # perf_counter() is monotonic; time() can jump with wall-clock changes.
+    from time import perf_counter
+    start = perf_counter()
+    for _ in range(nrounds):
+        run()
+
+    queue.finish()
+    one_round = (perf_counter() - start)/nrounds
+    print(f"DOFs/s: {n**3/one_round}, elapsed per round: {one_round} s")
+
+
+# This lets you run 'python test.py test_case(cl._csc)' without pytest.
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])