From 340b78d89b363f7013db3c27edffc2bdc125af9a Mon Sep 17 00:00:00 2001
From: "Timothy A. Smith" <tasmith4@illinois.edu>
Date: Mon, 27 May 2019 22:36:52 -0500
Subject: [PATCH] create new kernel_fixtures.py for direct interface with Loopy
 kernels

---
 fixtures.py        | 10 ----------
 kernel_fixtures.py | 15 +++++++++++++++
 test.py            |  6 ++++--
 3 files changed, 19 insertions(+), 12 deletions(-)
 create mode 100644 kernel_fixtures.py

diff --git a/fixtures.py b/fixtures.py
index d11810c..10d0f45 100644
--- a/fixtures.py
+++ b/fixtures.py
@@ -92,16 +92,6 @@ def f_array(queue, *shape):
     ary = np.random.random_sample(shape).astype(np.float32).copy(order="F")
     return cl.array.to_device(queue, ary)
 
-def mult_mat_vec(ctx_factory, alpha, a, b):
-    queue = get_queue(ctx_factory)
-
-    c_dev = cl.array.empty(queue, 10, dtype=np.float32)
-
-    prg = with_root_kernel(get_weno_program(), "mult_mat_vec")
-    prg(queue, a=a, b=b, c=c_dev, alpha=alpha)
-
-    return c_dev.get()
-
 def compute_flux_derivatives(ctx_factory,
         nvars, ndim, nx, ny, nz,
         states, fluxes, metrics, metric_jacobians):
diff --git a/kernel_fixtures.py b/kernel_fixtures.py
new file mode 100644
index 0000000..d480a5d
--- /dev/null
+++ b/kernel_fixtures.py
@@ -0,0 +1,15 @@
+import numpy as np
+import pyopencl as cl
+
+import fixtures
+
+def mult_mat_vec(ctx_factory, alpha, a, b):
+    queue = fixtures.get_queue(ctx_factory)
+
+    c_dev = cl.array.empty(queue, 10, dtype=np.float32)
+
+    prg = fixtures.with_root_kernel(fixtures.get_weno_program(), "mult_mat_vec")
+    prg(queue, a=a, b=b, c=c_dev, alpha=alpha)
+
+    return c_dev.get()
+
diff --git a/test.py b/test.py
index 7a8d406..555e953 100644
--- a/test.py
+++ b/test.py
@@ -17,17 +17,19 @@ from pyopencl.tools import (  # noqa
 import fixtures
 import comparison_fixtures as compare
 import setup_fixtures as setup
+import kernel_fixtures as kernel
+
 
 def test_matvec(ctx_factory):
     a = setup.random_array(10, 10)
     b = setup.random_array(10)
 
-    c = fixtures.mult_mat_vec(ctx_factory, a=a, b=b, alpha=1.0)
+    c = kernel.mult_mat_vec(ctx_factory, a=a, b=b, alpha=1.0)
 
     compare.arrays(a@b, c)
 
 
-#@pytest.mark.skip("slow")
+@pytest.mark.skip("slow")
 def test_compute_flux_derivatives(ctx_factory):
     logging.basicConfig(level="INFO")
 
-- 
GitLab