diff --git a/fixtures.py b/fixtures.py index c5737db8b2db5c09648824b2c802722821dd958b..6a1e7919a864a34b7698900887f9e932a4a3b3d2 100644 --- a/fixtures.py +++ b/fixtures.py @@ -58,18 +58,6 @@ def get_gpu_transformed_weno(): -def with_root_kernel(prg, root_name): - # FIXME This is a little less beautiful than it could be - new_prg = prg.copy(name=root_name) - for name in prg: - clbl = new_prg[name] - if isinstance(clbl, lp.LoopKernel) and clbl.is_called_from_host: - new_prg = new_prg.with_kernel(clbl.copy(is_called_from_host=False)) - - new_prg = new_prg.with_kernel(prg[root_name].copy(is_called_from_host=True)) - return new_prg - - def get_weno_program(): if _WENO_PRG: return _WENO_PRG[0] diff --git a/kernel_fixtures.py b/kernel_fixtures.py index 610e4697c08309a0cba41cd8c8cbfd0d6fe2e29e..cbef8a36906754d41c15491bdf1763d08bb2d1b3 100644 --- a/kernel_fixtures.py +++ b/kernel_fixtures.py @@ -5,10 +5,22 @@ import loopy as lp # noqa import fixtures +def with_root_kernel(prg, root_name): + # FIXME This is a little less beautiful than it could be + new_prg = prg.copy(name=root_name) + for name in prg: + clbl = new_prg[name] + if isinstance(clbl, lp.LoopKernel) and clbl.is_called_from_host: + new_prg = new_prg.with_kernel(clbl.copy(is_called_from_host=False)) + + new_prg = new_prg.with_kernel(prg[root_name].copy(is_called_from_host=True)) + return new_prg + + def mult_mat_vec(queue, prg, alpha, a, b): c_dev = cl.array.empty(queue, 10, dtype=np.float32) - prg = fixtures.with_root_kernel(prg, "mult_mat_vec") + prg = with_root_kernel(prg, "mult_mat_vec") prg(queue, a=a, b=b, c=c_dev, alpha=alpha) return c_dev.get()