diff --git a/fixtures.py b/fixtures.py index f2e8f28508adf2c317c31df1796213914b3a0a63..257f58216af3fa307929dec6a4b79d34252ebfc5 100644 --- a/fixtures.py +++ b/fixtures.py @@ -1,5 +1,5 @@ import numpy as np -import numpy.linalg as la +import numpy.linalg as la # noqa import pyopencl as cl import pyopencl.array # noqa import pyopencl.clrandom # noqa @@ -65,12 +65,24 @@ class LoopyFixture: queue = self.get_queue(ctx_factory) - prg = lp.fix_parameters(self.prg, nx=nx, ny=ny, nz=nz, _remove=False) + prg = self.prg + cfd = prg["compute_flux_derivatives"] + + cfd = lp.assume(cfd, "nx > 0 and ny > 0 and nz > 0") + + cfd = lp.set_temporary_scope(cfd, "flux_derivatives_generalized", + lp.AddressSpace.GLOBAL) + cfd = lp.set_temporary_scope(cfd, "generalized_fluxes", + lp.AddressSpace.GLOBAL) + cfd = lp.set_temporary_scope(cfd, "weno_flux_tmp", + lp.AddressSpace.GLOBAL) + + prg = prg.with_kernel(cfd) flux_derivatives_dev = cl.array.empty(queue, (nvars, ndim, nx+6, ny+6, nz+6), dtype=np.float32, order="F") - prg(queue, nvars=nvars, ndim=ndim, nx=nx, ny=ny, nz=nz, + prg(queue, nvars=nvars, ndim=ndim, states=states, fluxes=fluxes, metrics=metrics, metric_jacobians=metric_jacobians, flux_derivatives=flux_derivatives_dev)