diff --git a/kernel_fixtures.py b/kernel_fixtures.py
index 278d5f7d4f69a565f5083a49a23d3dfa00c5701b..c2b9f30b09246d719ac46d6c91edd1412d19ce6a 100644
--- a/kernel_fixtures.py
+++ b/kernel_fixtures.py
@@ -40,24 +40,3 @@ def compute_flux_derivatives(queue, prg, params, arrays):
             flux_derivatives=flux_derivatives_dev)
 
     return flux_derivatives_dev.get()
-
-def compute_flux_derivatives_gpu(ctx_factory, params, arrays):
-    prg = transform.get_gpu_transformed_weno()
-
-    queue = device.get_queue(ctx_factory)
-
-    flux_derivatives_dev = cl.array.empty(queue, (params.nvars, params.ndim,
-        params.nx_halo, params.ny_halo, params.nz_halo), dtype=np.float32, order="F")
-
-    prg = prg.copy(target=lp.PyOpenCLTarget(queue.device))
-
-    if 1:
-        with open("gen-code.cl", "w") as outf:
-            outf.write(lp.generate_code_v2(prg).device_code())
-
-    prg = lp.set_options(prg, no_numpy=True)
-
-    prg(queue, nvars=params.nvars, ndim=params.ndim,
-            states=arrays.states, fluxes=arrays.fluxes, metrics=arrays.metrics,
-            metric_jacobians=arrays.metric_jacobians,
-            flux_derivatives=flux_derivatives_dev)
diff --git a/test.py b/test.py
index 9e1ff7f43b1a0c9821cc242c5e534af9b1da9e9b..8e883b84b2671f75d1dcd2a4988cd3f1abe36196 100644
--- a/test.py
+++ b/test.py
@@ -36,12 +36,16 @@ def test_compute_flux_derivatives(ctx_factory):
     kernel.compute_flux_derivatives(queue, prg, params, arrays)
 
 
-@pytest.mark.skip("slow")
 def test_compute_flux_derivatives_gpu(ctx_factory):
+    queue = device.get_queue(ctx_factory)  # one queue, reused for the transform and the kernel call below
+    prg = program.get_weno()
+    prg = transform.weno_for_gpu(prg)
+    prg = transform.compute_flux_derivative_gpu(queue, prg)  # retargets prg to queue.device; also writes gen-code.cl to the CWD
+
     params = setup.flux_derivative_params(ndim=3, nvars=5, n=10)
     arrays = setup.random_flux_derivative_arrays_on_device(ctx_factory, params)
 
-    kernel.compute_flux_derivatives_gpu(ctx_factory, params, arrays)
+    kernel.compute_flux_derivatives(queue, prg, params, arrays)  # result is discarded: the test only checks that the call runs
 
 
 # This lets you run 'python test.py test_case(cl._csc)' without pytest.
diff --git a/transform_fixtures.py b/transform_fixtures.py
index 7b29fbcba3352432e5452f4a7a98836d2254cef7..a985e97f60e58fdcb48bd2f9e2b08dba35b6be2b 100644
--- a/transform_fixtures.py
+++ b/transform_fixtures.py
@@ -16,6 +16,27 @@ def compute_flux_derivative_basic(prg):
     return prg.with_kernel(cfd)
 
 
+def compute_flux_derivative_gpu(queue, prg, code_dump_path="gen-code.cl"):
+    """Retarget *prg* at the device of *queue* and enable ``no_numpy``.
+
+    :arg queue: its ``device`` attribute is handed to ``lp.PyOpenCLTarget``.
+    :arg prg: program to retarget; a copy is made with ``prg.copy``.
+    :arg code_dump_path: file that receives the generated device code for
+        inspection. The default, ``"gen-code.cl"``, is a relative path, so
+        the file lands in the current working directory. Pass ``None`` to
+        skip the dump.
+    :returns: the retargeted program with ``no_numpy=True`` set.
+    """
+    prg = prg.copy(target=lp.PyOpenCLTarget(queue.device))
+
+    # Debugging aid only: nothing in this function reads the file back.
+    if code_dump_path is not None:
+        with open(code_dump_path, "w") as outf:
+            outf.write(lp.generate_code_v2(prg).device_code())
+
+    return lp.set_options(prg, no_numpy=True)
+
+
 def weno_for_gpu(prg):
     cfd = prg["compute_flux_derivatives"]