diff --git a/WENO.F90 b/WENO.F90 index 995d3b7d86fcb177684c4e97071db6d5b6b249ec..6bfad5c1d58ed39569fb0a308a426215d35a13e8 100644 --- a/WENO.F90 +++ b/WENO.F90 @@ -951,6 +951,20 @@ end subroutine ! ! prg = lp.parse_fortran(lp.c_preprocess(SOURCE), FILENAME) ! prg = lp.fix_parameters(prg, ndim=3, nvars=5, _remove=False) +! +! cfd = prg["compute_flux_derivatives"] +! +! cfd = lp.assume(cfd, "nx > 0 and ny > 0 and nz > 0") +! +! cfd = lp.set_temporary_scope(cfd, "flux_derivatives_generalized", +! lp.AddressSpace.GLOBAL) +! cfd = lp.set_temporary_scope(cfd, "generalized_fluxes", +! lp.AddressSpace.GLOBAL) +! cfd = lp.set_temporary_scope(cfd, "weno_flux_tmp", +! lp.AddressSpace.GLOBAL) +! +! prg = prg.with_kernel(cfd) +! ! RESULT = prg ! !$loopy end diff --git a/benchmark.py b/benchmark.py index 00034a7648dc8d7e31ecf85ad36040db2ba02ed5..5f48726fa6057f3931f181174b442fa4a989c735 100644 --- a/benchmark.py +++ b/benchmark.py @@ -14,18 +14,22 @@ from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests) -import device_fixtures as device -import program_fixtures as program -import transform_fixtures as transform -import setup_fixtures as setup +from utilities import * -def benchmark_compute_flux_derivatives_gpu(ctx_factory): + +def benchmark_compute_flux_derivatives_gpu(ctx_factory, write_code=False): logging.basicConfig(level="INFO") - prg = program.get_weno() - prg = transform.weno_for_gpu(prg) + prg = get_weno_program() + prg = transform_weno_for_gpu(prg) - queue = device.get_queue(ctx_factory) + queue = get_queue(ctx_factory) + prg = prg.copy(target=lp.PyOpenCLTarget(queue.device)) + prg = lp.set_options(prg, no_numpy=True) + prg = lp.set_options(prg, ignore_boostable_into=True) + #prg = lp.set_options(prg, write_wrapper=True) + #op_map = lp.get_op_map(prg, count_redundant_work=False) + #print(op_map) ndim = 3 nvars = 5 @@ -35,27 +39,16 @@ def benchmark_compute_flux_derivatives_gpu(ctx_factory): nz = n print("ARRAY GEN") - states = setup.random_array_on_device(queue, nvars, nx+6, ny+6, nz+6) - fluxes = setup.random_array_on_device(queue, nvars, ndim, nx+6, ny+6, nz+6) - metrics = setup.random_array_on_device(queue, ndim, ndim, nx+6, ny+6, nz+6) - metric_jacobians = setup.random_array_on_device(queue, nx+6, ny+6, nz+6) + states = random_array_on_device(queue, nvars, nx+6, ny+6, nz+6) + fluxes = random_array_on_device(queue, nvars, ndim, nx+6, ny+6, nz+6) + metrics = random_array_on_device(queue, ndim, ndim, nx+6, ny+6, nz+6) + metric_jacobians = random_array_on_device(queue, nx+6, ny+6, nz+6) print("END ARRAY GEN") - flux_derivatives_dev = cl.array.empty(queue, (nvars, ndim, nx+6, ny+6, - nz+6), dtype=np.float32, order="F") - - prg = prg.copy(target=lp.PyOpenCLTarget(queue.device)) + flux_derivatives_dev = empty_array_on_device(queue, nvars, ndim, nx+6, ny+6, nz+6) - if 0: - with open("gen-code.cl", "w") as outf: - outf.write(lp.generate_code_v2(prg).device_code()) - - prg = prg.copy(target=lp.PyOpenCLTarget(queue.device)) - prg = lp.set_options(prg, ignore_boostable_into=True) - prg = lp.set_options(prg, no_numpy=True) - #prg = lp.set_options(prg, write_wrapper=True) - #op_map = lp.get_op_map(prg, count_redundant_work=False) - #print(op_map) + if write_code: + write_to_cl(prg) allocator = pyopencl.tools.MemoryPool(pyopencl.tools.ImmediateAllocator(queue)) diff --git a/comparison_fixtures.py b/comparison_fixtures.py deleted file mode 100644 index 04c743236fc597b1ec48ed24d5974360e6da6c7e..0000000000000000000000000000000000000000 --- a/comparison_fixtures.py +++ /dev/null @@ -1,20 +0,0 @@ -import numpy as np -from pytest import approx - - -def arrays(a, b): - assert a == approx(b) - - -def roe_identity(states, R, Rinv): - dState = states[:,1] - states[:,0] - arrays(R@(Rinv@dState), dState) - - -def roe_property(states, fluxes, R, Rinv, lam): - dState = states[:,1] - states[:,0] - dFlux = fluxes[:,1] - fluxes[:,0] - - temp = Rinv@dState - temp = np.multiply(lam, temp) - arrays(R@temp, dFlux) diff --git a/device_fixtures.py b/device_fixtures.py deleted file mode 100644 index d0dbc5929dc744a0fed6354f3aa8aa34897bf57c..0000000000000000000000000000000000000000 --- a/device_fixtures.py +++ /dev/null @@ -1,15 +0,0 @@ -import pyopencl as cl - - -_QUEUE = [] - - -def get_queue(ctx_factory): - if not _QUEUE: - setup_queue(ctx_factory) - return _QUEUE[0] - - -def setup_queue(ctx_factory): - ctx = ctx_factory() - _QUEUE.append(cl.CommandQueue(ctx)) diff --git a/kernel_fixtures.py b/kernel_fixtures.py deleted file mode 100644 index 7f3dff465c30ab8fc9101ab092ef4c134bcfa951..0000000000000000000000000000000000000000 --- a/kernel_fixtures.py +++ /dev/null @@ -1,49 +0,0 @@ -import loopy as lp # noqa - -import setup_fixtures as setup - - -def with_root_kernel(prg, root_name): - # FIXME This is a little less beautiful than it could be - new_prg = prg.copy(name=root_name) - for name in prg: - clbl = new_prg[name] - if isinstance(clbl, lp.LoopKernel) and clbl.is_called_from_host: - new_prg = new_prg.with_kernel(clbl.copy(is_called_from_host=False)) - - new_prg = new_prg.with_kernel(prg[root_name].copy(is_called_from_host=True)) - return new_prg - - -def roe_eigensystem(queue, prg, params, states, metrics_frozen): - R_dev = setup.empty_array_on_device(queue, params.mat_bounds()) - Rinv_dev = setup.empty_array_on_device(queue, params.mat_bounds()) - lam_dev = setup.empty_array_on_device(queue, params.vec_bounds()) - - prg = with_root_kernel(prg, "roe_eigensystem") - prg(queue, nvars=params.nvars, ndim=params.ndim, d=params.d, - states=states, metrics_frozen=metrics_frozen, - R=R_dev, R_inv=Rinv_dev, lambda_roe=lam_dev) - - return R_dev.get(), Rinv_dev.get(), lam_dev.get() - - -def mult_mat_vec(queue, prg, alpha, a, b): - c_dev = setup.empty_array_on_device(queue, b.shape) - - prg = with_root_kernel(prg, "mult_mat_vec") - prg(queue, a=a, b=b, c=c_dev, alpha=alpha) - - return c_dev.get() - - -def compute_flux_derivatives(queue, prg, params, arrays): - flux_derivatives_dev = setup.empty_array_on_device(queue, (params.nvars, params.ndim, - params.nx_halo, params.ny_halo, params.nz_halo)) - - prg(queue, nvars=params.nvars, ndim=params.ndim, - states=arrays.states, fluxes=arrays.fluxes, metrics=arrays.metrics, - metric_jacobians=arrays.metric_jacobians, - flux_derivatives=flux_derivatives_dev) - - return flux_derivatives_dev.get() diff --git a/program_fixtures.py b/program_fixtures.py deleted file mode 100644 index 0f50ff1c364bc59f201fb1e551a34fde1602c6f2..0000000000000000000000000000000000000000 --- a/program_fixtures.py +++ /dev/null @@ -1,20 +0,0 @@ -import loopy as lp - - -_WENO_PRG = [] - - -def parse_weno(): - fn = "WENO.F90" - - with open(fn, "r") as infile: - infile_content = infile.read() - - prg = lp.parse_transformed_fortran(infile_content, filename=fn) - _WENO_PRG.append(prg) - - -def get_weno(): - if not _WENO_PRG: - parse_weno() - return _WENO_PRG[0] diff --git a/setup_fixtures.py b/setup_fixtures.py deleted file mode 100644 index 6f1debcb3e2bc1b8acb16f45d58bd3a16771a25a..0000000000000000000000000000000000000000 --- a/setup_fixtures.py +++ /dev/null @@ -1,138 +0,0 @@ -import numpy as np -import pyopencl as cl -import pyopencl.array # noqa - -import device_fixtures as device - - -class RoeParams: - def __init__(self, nvars, ndim, d): - self.nvars = nvars - self.ndim = ndim - self.d = d - - def mat_bounds(self): - return self.nvars, self.nvars - - def vec_bounds(self): - return self.nvars - - -class FluxDerivativeParams: - def __init__(self, nvars, ndim, nx, ny, nz): - self.nvars = nvars - self.ndim = ndim - - self.nx = nx - self.ny = ny - self.nz = nz - - self.nhalo = 3 - self.nx_halo = self.nx + 2*self.nhalo - self.ny_halo = self.ny + 2*self.nhalo - self.nz_halo = self.nz + 2*self.nhalo - - def state_bounds(self): - return self.nvars, self.nx_halo, self.ny_halo, self.nz_halo - - def flux_bounds(self): - return self.nvars, self.ndim, self.nx_halo, self.ny_halo, self.nz_halo - - def metric_bounds(self): - return self.ndim, self.ndim, self.nx_halo, self.ny_halo, self.nz_halo - - def jacobian_bounds(self): - return self.nx_halo, self.ny_halo, self.nz_halo - - -class FluxDerivativeArrays: - def __init__(self, states, fluxes, metrics, metric_jacobians): - self.states = states - self.fluxes = fluxes - self.metrics = metrics - self.metric_jacobians = metric_jacobians - - -def roe_params(nvars, ndim, direction): - dirs = {"x" : 1, "y" : 2, "z" : 3} - return RoeParams(nvars, ndim, dirs[direction]) - - -def flux_derivative_params(nvars, ndim, n): - return FluxDerivativeParams(nvars, ndim, n, n, n) - - -def empty_array_on_device(queue, shape): - return cl.array.empty(queue, shape, dtype=np.float32, order="F") - - -def identity(n): - return np.identity(n).astype(np.float32).copy(order="F") - - -def random_array(*shape): - return np.random.random_sample(shape).astype(np.float32).copy(order="F") - - -def random_array_on_device(queue, *shape): - return cl.array.to_device(queue, random_array(*shape)) - - -def random_flux_derivative_arrays(params): - states = random_array(*params.state_bounds()) - fluxes = random_array(*params.flux_bounds()) - metrics = random_array(*params.metric_bounds()) - metric_jacobians = random_array(*params.jacobian_bounds()) - - return FluxDerivativeArrays(states, fluxes, metrics, metric_jacobians) - - -def random_flux_derivative_arrays_on_device(ctx_factory, params): - queue = device.get_queue(ctx_factory) - - states = random_array_on_device(queue, *params.state_bounds()) - fluxes = random_array_on_device(queue, *params.flux_bounds()) - metrics = random_array_on_device(queue, *params.metric_bounds()) - metric_jacobians = random_array_on_device(queue, *params.jacobian_bounds()) - - return FluxDerivativeArrays(states, fluxes, metrics, metric_jacobians) - - -def arrays_from_string(string_arrays): - return split_map_to_list(string_arrays, array_from_string, ":") - - -def array_from_string(string_array): - if ";" not in string_array: - if "," not in string_array: - array = array_from_string_1d(string_array) - else: - array = array_from_string_2d(string_array) - else: - array = array_from_string_3d(string_array) - return array.copy(order="F") - - -def array_from_string_3d(string_array): - if string_array[0] == ";": - return array_from_string_1d(string_array[1:]).reshape((-1, 1, 1)) - else: - return np.array(split_map_to_list(string_array, array_from_string_2d, ";")) - - -def array_from_string_2d(string_array): - if string_array[0] == ",": - return array_from_string_1d(string_array[1:]).reshape((-1, 1)) - else: - return np.array(split_map_to_list(string_array, array_from_string_1d, ",")) - - -def array_from_string_1d(string_array): - if string_array[0] == "i": - return np.array(split_map_to_list(string_array[1:], int, " ")) - else: - return np.array(split_map_to_list(string_array, float, " "), dtype=np.float32) - - -def split_map_to_list(string, map_func, splitter): - return list(map(map_func, string.split(splitter))) diff --git a/test.py b/test.py index 924171f561ad96e1eec9984bf6f46c7d8cae589a..e9727348f9ce65c2c8476ed76ed117786237f007 100644 --- a/test.py +++ b/test.py @@ -1,18 +1,22 @@ +import numpy as np +import numpy.linalg as la +import pyopencl as cl +import pyopencl.array # noqa +import pyopencl.tools # noqa +import pyopencl.clrandom # noqa +import loopy as lp # noqa + import sys import logging import pytest -import pyopencl as cl +from pytest import approx from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests) -import device_fixtures as device -import program_fixtures as program -import transform_fixtures as transform -import setup_fixtures as setup -import kernel_fixtures as kernel -import comparison_fixtures as compare +from utilities import * + @pytest.mark.xfail @pytest.mark.parametrize("states_str,fluxes_str,direction", [ @@ -27,54 +31,103 @@ import comparison_fixtures as compare ("2 1,4 1,8 2,12 3,64 11", "12 3,24 3,48 6,75.2 10.6,403.2 37.8", "z") ]) def test_roe_uniform_grid(ctx_factory, states_str, fluxes_str, direction): - queue = device.get_queue(ctx_factory) - prg = program.get_weno() + class RoeParams: + def __init__(self, nvars, ndim, d): + self.nvars = nvars + self.ndim = ndim + self.d = d + + def mat_bounds(self): + return self.nvars, self.nvars + + def vec_bound(self): + return self.nvars + + def setup_roe_params(nvars, ndim, direction): + dirs = {"x" : 1, "y" : 2, "z" : 3} + return RoeParams(nvars, ndim, dirs[direction]) + + def identity_matrix(n): + return np.identity(n).astype(np.float32).copy(order="F") + + def kernel_roe_eigensystem(queue, prg, params, states, metrics_frozen): + R_dev = empty_array_on_device(queue, *params.mat_bounds()) + Rinv_dev = empty_array_on_device(queue, *params.mat_bounds()) + lam_dev = empty_array_on_device(queue, params.vec_bound()) - params = setup.roe_params(nvars=5, ndim=3, direction=direction) - states = setup.array_from_string(states_str) - metrics_frozen = setup.identity(params.ndim) - R, Rinv, lam = kernel.roe_eigensystem(queue, prg, params, states, metrics_frozen) + prg = with_root_kernel(prg, "roe_eigensystem") + prg(queue, nvars=params.nvars, ndim=params.ndim, d=params.d, + states=states, metrics_frozen=metrics_frozen, + R=R_dev, R_inv=Rinv_dev, lambda_roe=lam_dev) - compare.roe_identity(states, R, Rinv) + return R_dev.get(), Rinv_dev.get(), lam_dev.get() - fluxes = setup.array_from_string(fluxes_str) - compare.roe_property(states, fluxes, R, Rinv, lam) + def check_roe_identity(states, R, Rinv): + dState = states[:,1] - states[:,0] + compare_arrays(R@(Rinv@dState), dState) + + def check_roe_property(states, fluxes, R, Rinv, lam): + dState = states[:,1] - states[:,0] + dFlux = fluxes[:,1] - fluxes[:,0] + + temp = Rinv@dState + temp = np.multiply(lam, temp) + compare_arrays(R@temp, dFlux) + + queue = get_queue(ctx_factory) + prg = get_weno_program() + + params = setup_roe_params(nvars=5, ndim=3, direction=direction) + states = array_from_string(states_str) + metrics_frozen = identity_matrix(params.ndim) + R, Rinv, lam = kernel_roe_eigensystem(queue, prg, params, states, metrics_frozen) + + check_roe_identity(states, R, Rinv) + + fluxes = array_from_string(fluxes_str) + check_roe_property(states, fluxes, R, Rinv, lam) def test_matvec(ctx_factory): - queue = device.get_queue(ctx_factory) - prg = program.get_weno() + prg = get_weno_program() + queue = get_queue(ctx_factory) - a = setup.random_array(10, 10) - b = setup.random_array(10) + a = random_array_on_device(queue, 10, 10) + b = random_array_on_device(queue, 10) - c = kernel.mult_mat_vec(queue, prg, alpha=1.0, a=a, b=b) + c = empty_array_on_device(queue, 10) - compare.arrays(a@b, c) + prg = with_root_kernel(prg, "mult_mat_vec") + prg(queue, alpha=1.0, a=a, b=b, c=c) + + compare_arrays(a.get()@b.get(), c.get()) @pytest.mark.slow def test_compute_flux_derivatives(ctx_factory): - queue = device.get_queue(ctx_factory) - prg = program.get_weno() - prg = transform.compute_flux_derivative_basic(prg) + prg = get_weno_program() - params = setup.flux_derivative_params(ndim=3, nvars=5, n=10) - arrays = setup.random_flux_derivative_arrays(params) + queue = get_queue(ctx_factory) + prg = prg.copy(target=lp.PyOpenCLTarget(queue.device)) - kernel.compute_flux_derivatives(queue, prg, params, arrays) + lp.auto_test_vs_ref(prg, ctx_factory(), warmup_rounds=1, + parameters=dict(ndim=3, nvars=5, nx=16, ny=16, nz=16)) @pytest.mark.slow -def test_compute_flux_derivatives_gpu(ctx_factory): - queue = device.get_queue(ctx_factory) - prg = program.get_weno() - prg = transform.compute_flux_derivative_gpu(queue, prg) +def test_compute_flux_derivatives_gpu(ctx_factory, write_code=False): + prg = get_weno_program() + prg = transform_weno_for_gpu(prg) + + queue = get_queue(ctx_factory) + prg = prg.copy(target=lp.PyOpenCLTarget(queue.device)) + prg = lp.set_options(prg, no_numpy=True) - params = setup.flux_derivative_params(ndim=3, nvars=5, n=10) - arrays = setup.random_flux_derivative_arrays_on_device(ctx_factory, params) + if write_code: + write_to_cl(prg) - kernel.compute_flux_derivatives(queue, prg, params, arrays) + lp.auto_test_vs_ref(prg, ctx_factory(), warmup_rounds=1, + parameters=dict(ndim=3, nvars=5, nx=16, ny=16, nz=16)) # This lets you run 'python test.py test_case(cl._csc)' without pytest. diff --git a/transform_fixtures.py b/transform_fixtures.py deleted file mode 100644 index f69581a045eae6d0289910fc635b2b289c0d5178..0000000000000000000000000000000000000000 --- a/transform_fixtures.py +++ /dev/null @@ -1,65 +0,0 @@ -import loopy as lp - - -def compute_flux_derivative_basic(prg): - cfd = prg["compute_flux_derivatives"] - - cfd = lp.assume(cfd, "nx > 0 and ny > 0 and nz > 0") - - cfd = lp.set_temporary_scope(cfd, "flux_derivatives_generalized", - lp.AddressSpace.GLOBAL) - cfd = lp.set_temporary_scope(cfd, "generalized_fluxes", - lp.AddressSpace.GLOBAL) - cfd = lp.set_temporary_scope(cfd, "weno_flux_tmp", - lp.AddressSpace.GLOBAL) - - return prg.with_kernel(cfd) - - -def weno_for_gpu(prg): - prg = compute_flux_derivative_basic(prg) - - cfd = prg["compute_flux_derivatives"] - - for suffix in ["", "_1", "_2", "_3", "_4", "_5", "_6", "_7"]: - cfd = lp.split_iname(cfd, "i"+suffix, 16, - outer_tag="g.0", inner_tag="l.0") - cfd = lp.split_iname(cfd, "j"+suffix, 16, - outer_tag="g.1", inner_tag="l.1") - - for var_name in ["delta_xi", "delta_eta", "delta_zeta"]: - cfd = lp.assignment_to_subst(cfd, var_name) - - cfd = lp.add_barrier(cfd, "tag:to_generalized", "tag:flux_x_compute") - cfd = lp.add_barrier(cfd, "tag:flux_x_compute", "tag:flux_x_diff") - cfd = lp.add_barrier(cfd, "tag:flux_x_diff", "tag:flux_y_compute") - cfd = lp.add_barrier(cfd, "tag:flux_y_compute", "tag:flux_y_diff") - cfd = lp.add_barrier(cfd, "tag:flux_y_diff", "tag:flux_z_compute") - cfd = lp.add_barrier(cfd, "tag:flux_z_compute", "tag:flux_z_diff") - cfd = lp.add_barrier(cfd, "tag:flux_z_diff", "tag:from_generalized") - - prg = prg.with_kernel(cfd) - - # FIXME: These should work, but don't - # FIXME: Undo the hand-inlining in WENO.F90 - #prg = lp.inline_callable_kernel(prg, "convert_to_generalized") - #prg = lp.inline_callable_kernel(prg, "convert_from_generalized") - - if 0: - print(prg["convert_to_generalized_frozen"]) - 1/0 - - return prg - - -def compute_flux_derivative_gpu(queue, prg): - prg = weno_for_gpu(prg) - - prg = prg.copy(target=lp.PyOpenCLTarget(queue.device)) - - if 1: - with open("gen-code.cl", "w") as outf: - outf.write(lp.generate_code_v2(prg).device_code()) - - prg = lp.set_options(prg, no_numpy=True) - return prg diff --git a/utilities.py b/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..306c28ef4dc3cd331da5cca236badf7f0bbc315d --- /dev/null +++ b/utilities.py @@ -0,0 +1,150 @@ +import numpy as np +import numpy.linalg as la +import pyopencl as cl +import pyopencl.array # noqa +import pyopencl.tools # noqa +import pyopencl.clrandom # noqa +import loopy as lp # noqa +from pytest import approx + + +### Arrays ### + +def compare_arrays(a, b): + assert a == approx(b) + + +def random_array_on_device(queue, *shape): + ary = empty_array_on_device(queue, *shape) + cl.clrandom.fill_rand(ary) + return ary + + +def empty_array_on_device(queue, *shape): + return cl.array.empty(queue, shape, dtype=np.float32, order="F") + + +def arrays_from_string(string_arrays): + return split_map_to_list(string_arrays, array_from_string, ":") + + +def array_from_string(string_array): + def array_from_string_1d(string_array): + if string_array[0] == "i": + return np.array(split_map_to_list(string_array[1:], int, " ")) + else: + return np.array(split_map_to_list(string_array, float, " "), dtype=np.float32) + + def array_from_string_2d(string_array): + if string_array[0] == ",": + return array_from_string_1d(string_array[1:]).reshape((-1, 1)) + else: + return np.array(split_map_to_list(string_array, array_from_string_1d, ",")) + + def array_from_string_3d(string_array): + if string_array[0] == ";": + return array_from_string_1d(string_array[1:]).reshape((-1, 1, 1)) + else: + return np.array(split_map_to_list(string_array, array_from_string_2d, ";")) + + if ";" not in string_array: + if "," not in string_array: + array = array_from_string_1d(string_array) + else: + array = array_from_string_2d(string_array) + else: + array = array_from_string_3d(string_array) + return array.copy(order="F") + + +def split_map_to_list(string, map_func, splitter): + return list(map(map_func, string.split(splitter))) + + +### Device ### + +_QUEUE = [] + + +def get_queue(ctx_factory): + if not _QUEUE: + setup_queue(ctx_factory) + return _QUEUE[0] + + +def setup_queue(ctx_factory): + ctx = ctx_factory() + _QUEUE.append(cl.CommandQueue(ctx)) + + +### Program / Kernel ### + +_WENO_PRG = [] + + +def get_weno_program(): + if not _WENO_PRG: + parse_weno() + return _WENO_PRG[0] + + +def parse_weno(): + fn = "WENO.F90" + + with open(fn, "r") as infile: + infile_content = infile.read() + + prg = lp.parse_transformed_fortran(infile_content, filename=fn) + _WENO_PRG.append(prg) + + +def with_root_kernel(prg, root_name): + # FIXME This is a little less beautiful than it could be + new_prg = prg.copy(name=root_name) + for name in prg: + clbl = new_prg[name] + if isinstance(clbl, lp.LoopKernel) and clbl.is_called_from_host: + new_prg = new_prg.with_kernel(clbl.copy(is_called_from_host=False)) + + new_prg = new_prg.with_kernel(prg[root_name].copy(is_called_from_host=True)) + return new_prg + + +def transform_weno_for_gpu(prg, print_kernel=False): + cfd = prg["compute_flux_derivatives"] + + for suffix in ["", "_1", "_2", "_3", "_4", "_5", "_6", "_7"]: + cfd = lp.split_iname(cfd, "i"+suffix, 16, + outer_tag="g.0", inner_tag="l.0") + cfd = lp.split_iname(cfd, "j"+suffix, 16, + outer_tag="g.1", inner_tag="l.1") + + for var_name in ["delta_xi", "delta_eta", "delta_zeta"]: + cfd = lp.assignment_to_subst(cfd, var_name) + + cfd = lp.add_barrier(cfd, "tag:to_generalized", "tag:flux_x_compute") + cfd = lp.add_barrier(cfd, "tag:flux_x_compute", "tag:flux_x_diff") + cfd = lp.add_barrier(cfd, "tag:flux_x_diff", "tag:flux_y_compute") + cfd = lp.add_barrier(cfd, "tag:flux_y_compute", "tag:flux_y_diff") + cfd = lp.add_barrier(cfd, "tag:flux_y_diff", "tag:flux_z_compute") + cfd = lp.add_barrier(cfd, "tag:flux_z_compute", "tag:flux_z_diff") + cfd = lp.add_barrier(cfd, "tag:flux_z_diff", "tag:from_generalized") + + prg = prg.with_kernel(cfd) + + # FIXME: These should work, but don't + # FIXME: Undo the hand-inlining in WENO.F90 + #prg = lp.inline_callable_kernel(prg, "convert_to_generalized") + #prg = lp.inline_callable_kernel(prg, "convert_from_generalized") + + if print_kernel: + print(prg["convert_to_generalized_frozen"]) + 1/0 + + return prg + + +def write_to_cl(prg, outfilename="gen-code.cl"): + with open(outfilename, "w") as outf: + outf.write(lp.generate_code_v2(prg).device_code()) +