From 319b4d98cfcab6a92789197198347dad3acc54a9 Mon Sep 17 00:00:00 2001
From: "Timothy A. Smith" <tasmith4@illinois.edu>
Date: Tue, 28 May 2019 23:33:07 -0500
Subject: [PATCH] take transforms out of compute_flux_derivatives

---
 kernel_fixtures.py    | 19 ++-----------------
 setup_fixtures.py     |  4 ++++
 test.py               | 12 +++++++++---
 transform_fixtures.py | 15 +++++++++++++++
 4 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/kernel_fixtures.py b/kernel_fixtures.py
index 36b2e3c..58d19bc 100644
--- a/kernel_fixtures.py
+++ b/kernel_fixtures.py
@@ -29,23 +29,7 @@ def mult_mat_vec(queue, prg, alpha, a, b):
     return c_dev.get()
 
 
-def compute_flux_derivatives(ctx_factory, params, arrays):
-    queue = device.get_queue(ctx_factory)
-
-    prg = program.get_weno()
-    cfd = prg["compute_flux_derivatives"]
-
-    cfd = lp.assume(cfd, "nx > 0 and ny > 0 and nz > 0")
-
-    cfd = lp.set_temporary_scope(cfd, "flux_derivatives_generalized",
-            lp.AddressSpace.GLOBAL)
-    cfd = lp.set_temporary_scope(cfd, "generalized_fluxes",
-            lp.AddressSpace.GLOBAL)
-    cfd = lp.set_temporary_scope(cfd, "weno_flux_tmp",
-            lp.AddressSpace.GLOBAL)
-
-    prg = prg.with_kernel(cfd)
-
+def compute_flux_derivatives(queue, prg, params, arrays):
     flux_derivatives_dev = cl.array.empty(queue, (params.nvars, params.ndim,
         params.nx_halo, params.ny_halo, params.nz_halo), dtype=np.float32, order="F")
 
@@ -53,6 +37,7 @@ def compute_flux_derivatives(ctx_factory, params, arrays):
             states=arrays.states, fluxes=arrays.fluxes, metrics=arrays.metrics,
             metric_jacobians=arrays.metric_jacobians,
             flux_derivatives=flux_derivatives_dev)
+
     return flux_derivatives_dev.get()
 
 def compute_flux_derivatives_gpu(ctx_factory, params, arrays):
diff --git a/setup_fixtures.py b/setup_fixtures.py
index eaaa63e..add7fea 100644
--- a/setup_fixtures.py
+++ b/setup_fixtures.py
@@ -39,6 +39,10 @@ class FluxDerivativeArrays:
         self.metric_jacobians = metric_jacobians
 
 
+def flux_derivative_params(nvars, ndim, n):
+    return FluxDerivativeParams(nvars, ndim, n, n, n)
+
+
 def random_array(*shape):
     return np.random.random_sample(shape).astype(np.float32).copy(order="F")
 
diff --git a/test.py b/test.py
index 821db6c..9e1ff7f 100644
--- a/test.py
+++ b/test.py
@@ -8,6 +8,7 @@ from pyopencl.tools import (  # noqa
 
 import device_fixtures as device
 import program_fixtures as program
+import transform_fixtures as transform
 import setup_fixtures as setup
 import kernel_fixtures as kernel
 import comparison_fixtures as compare
@@ -25,14 +26,19 @@ def test_matvec(ctx_factory):
 
 
 def test_compute_flux_derivatives(ctx_factory):
-    params = setup.FluxDerivativeParams(ndim=3, nvars=5, nx=10, ny=10, nz=10)
+    queue = device.get_queue(ctx_factory)
+    prg = program.get_weno()
+    prg = transform.compute_flux_derivative_basic(prg)
+
+    params = setup.flux_derivative_params(ndim=3, nvars=5, n=10)
     arrays = setup.random_flux_derivative_arrays(params)
 
-    kernel.compute_flux_derivatives(ctx_factory, params, arrays)
+    kernel.compute_flux_derivatives(queue, prg, params, arrays)
 
 
+@pytest.mark.skip("slow")
 def test_compute_flux_derivatives_gpu(ctx_factory):
-    params = setup.FluxDerivativeParams(ndim=3, nvars=5, nx=10, ny=10, nz=10)
+    params = setup.flux_derivative_params(ndim=3, nvars=5, n=10)
     arrays = setup.random_flux_derivative_arrays_on_device(ctx_factory, params)
 
     kernel.compute_flux_derivatives_gpu(ctx_factory, params, arrays)
diff --git a/transform_fixtures.py b/transform_fixtures.py
index 33849bd..7b29fbc 100644
--- a/transform_fixtures.py
+++ b/transform_fixtures.py
@@ -1,6 +1,21 @@
 import loopy as lp
 
 
+def compute_flux_derivative_basic(prg):
+    cfd = prg["compute_flux_derivatives"]
+
+    cfd = lp.assume(cfd, "nx > 0 and ny > 0 and nz > 0")
+
+    cfd = lp.set_temporary_scope(cfd, "flux_derivatives_generalized",
+            lp.AddressSpace.GLOBAL)
+    cfd = lp.set_temporary_scope(cfd, "generalized_fluxes",
+            lp.AddressSpace.GLOBAL)
+    cfd = lp.set_temporary_scope(cfd, "weno_flux_tmp",
+            lp.AddressSpace.GLOBAL)
+
+    return prg.with_kernel(cfd)
+
+
 def weno_for_gpu(prg):
     cfd = prg["compute_flux_derivatives"]
 
-- 
GitLab