From 2268c61284eea12e9f1afa7771f849d32bd084a7 Mon Sep 17 00:00:00 2001
From: "Timothy A. Smith" <tasmith4@illinois.edu>
Date: Tue, 28 May 2019 23:45:23 -0500
Subject: [PATCH] replace kernel.compute_flux_derivatives_gpu with appropriate
 transform fixtures

---
 kernel_fixtures.py    | 21 ---------------------
 test.py               |  8 ++++++--
 transform_fixtures.py | 11 +++++++++++
 3 files changed, 17 insertions(+), 23 deletions(-)

diff --git a/kernel_fixtures.py b/kernel_fixtures.py
index 278d5f7..c2b9f30 100644
--- a/kernel_fixtures.py
+++ b/kernel_fixtures.py
@@ -40,24 +40,3 @@ def compute_flux_derivatives(queue, prg, params, arrays):
             flux_derivatives=flux_derivatives_dev)
 
     return flux_derivatives_dev.get()
-
-def compute_flux_derivatives_gpu(ctx_factory, params, arrays):
-    prg = transform.get_gpu_transformed_weno()
-
-    queue = device.get_queue(ctx_factory)
-
-    flux_derivatives_dev = cl.array.empty(queue, (params.nvars, params.ndim,
-        params.nx_halo, params.ny_halo, params.nz_halo), dtype=np.float32, order="F")
-
-    prg = prg.copy(target=lp.PyOpenCLTarget(queue.device))
-
-    if 1:
-        with open("gen-code.cl", "w") as outf:
-            outf.write(lp.generate_code_v2(prg).device_code())
-
-    prg = lp.set_options(prg, no_numpy=True)
-
-    prg(queue, nvars=params.nvars, ndim=params.ndim,
-            states=arrays.states, fluxes=arrays.fluxes, metrics=arrays.metrics,
-            metric_jacobians=arrays.metric_jacobians,
-            flux_derivatives=flux_derivatives_dev)
diff --git a/test.py b/test.py
index 9e1ff7f..8e883b8 100644
--- a/test.py
+++ b/test.py
@@ -36,12 +36,16 @@ def test_compute_flux_derivatives(ctx_factory):
     kernel.compute_flux_derivatives(queue, prg, params, arrays)
 
 
-@pytest.mark.skip("slow")
 def test_compute_flux_derivatives_gpu(ctx_factory):
+    queue = device.get_queue(ctx_factory)
+    prg = program.get_weno()
+    prg = transform.weno_for_gpu(prg)
+    prg = transform.compute_flux_derivative_gpu(queue, prg)
+    # Problem setup for the kernel call below: ndim=3, nvars=5, n=10.
     params = setup.flux_derivative_params(ndim=3, nvars=5, n=10)
     arrays = setup.random_flux_derivative_arrays_on_device(ctx_factory, params)
 
-    kernel.compute_flux_derivatives_gpu(ctx_factory, params, arrays)
+    kernel.compute_flux_derivatives(queue, prg, params, arrays)
 
 
 # This lets you run 'python test.py test_case(cl._csc)' without pytest.
diff --git a/transform_fixtures.py b/transform_fixtures.py
index 7b29fbc..a985e97 100644
--- a/transform_fixtures.py
+++ b/transform_fixtures.py
@@ -16,6 +16,17 @@ def compute_flux_derivative_basic(prg):
     return prg.with_kernel(cfd)
 
 
+def compute_flux_derivative_gpu(queue, prg):
+    """Retarget prg to queue.device, dump its device code, set no_numpy."""
+    device_prg = prg.copy(target=lp.PyOpenCLTarget(queue.device))
+
+    # Always writes the generated device code to gen-code.cl in the cwd.
+    with open("gen-code.cl", "w") as code_file:
+        code_file.write(lp.generate_code_v2(device_prg).device_code())
+
+    return lp.set_options(device_prg, no_numpy=True)
+
+
 def weno_for_gpu(prg):
     cfd = prg["compute_flux_derivatives"]
 
-- 
GitLab