diff --git a/test.py b/test.py index 823f50d0a426927fde5eaccb622b6941b58beff7..4ab9363d28f11a7bdf2f49ef8e5ae63bce73515c 100644 --- a/test.py +++ b/test.py @@ -47,6 +47,75 @@ def test_compute_flux_derivatives(ctx_factory): metric_jacobians=metric_jacobians) +def test_compute_flux_derivatives_gpu(ctx_factory): + logging.basicConfig(level="INFO") + + queue = f.get_queue(ctx_factory) + + ndim = 3 + nvars = 5 + nx = 10 + ny = 10 + nz = 10 + + states = f.random_array(nvars, nx+6, ny+6, nz+6) + fluxes = f.random_array(nvars, ndim, nx+6, ny+6, nz+6) + metrics = f.random_array(ndim, ndim, nx+6, ny+6, nz+6) + metric_jacobians = f.random_array(nx+6, ny+6, nz+6) + + prg = f.prg + + cfd = prg["compute_flux_derivatives"] + + cfd = lp.assume(cfd, "nx > 0 and ny > 0 and nz > 0") + + cfd = lp.set_temporary_scope(cfd, "flux_derivatives_generalized", + lp.AddressSpace.GLOBAL) + cfd = lp.set_temporary_scope(cfd, "generalized_fluxes", + lp.AddressSpace.GLOBAL) + cfd = lp.set_temporary_scope(cfd, "weno_flux_tmp", + lp.AddressSpace.GLOBAL) + + for suffix in ["", "_1", "_2", "_3", "_4", "_5", "_6"]: + cfd = lp.split_iname(cfd, "i"+suffix, 16, + outer_tag="g.0", inner_tag="l.0") + cfd = lp.split_iname(cfd, "j"+suffix, 16, + outer_tag="g.1", inner_tag="l.1") + + for var_name in ["delta_xi", "delta_eta", "delta_zeta"]: + cfd = lp.assignment_to_subst(cfd, var_name) + + cfd = lp.add_barrier(cfd, "tag:to_generalized", "tag:flux_x_compute") + cfd = lp.add_barrier(cfd, "tag:flux_x_compute", "tag:flux_x_diff") + cfd = lp.add_barrier(cfd, "tag:flux_x_diff", "tag:flux_y_compute") + cfd = lp.add_barrier(cfd, "tag:flux_y_compute", "tag:flux_y_diff") + cfd = lp.add_barrier(cfd, "tag:flux_y_diff", "tag:flux_z_compute") + cfd = lp.add_barrier(cfd, "tag:flux_z_compute", "tag:flux_z_diff") + cfd = lp.add_barrier(cfd, "tag:flux_z_diff", "tag:from_generalized") + + prg = prg.with_kernel(cfd) + + #prg = lp.inline_callable_kernel(prg, "convert_to_generalized") + #prg = lp.inline_callable_kernel(prg, "convert_from_generalized") + + if 0: + print(prg["convert_to_generalized_frozen"]) + 1/0 + + flux_derivatives_dev = cl.array.empty(queue, (nvars, ndim, nx+6, ny+6, + nz+6), dtype=np.float32, order="F") + + if 0: + with open("gen-code.cl", "w") as outf: + outf.write(lp.generate_code_v2(prg).device_code()) + + prg(queue, nvars=nvars, ndim=ndim, + states=states, fluxes=fluxes, metrics=metrics, + metric_jacobians=metric_jacobians, + flux_derivatives=flux_derivatives_dev) + return flux_derivatives_dev.get() + + # This lets you run 'python test.py test_case(cl._csc)' without pytest. if __name__ == "__main__":