import numpy as np
import pyopencl as cl

import fixtures


class FluxDerivativeParams:
    """Problem dimensions for a flux-derivative computation on a padded grid.

    Every spatial extent is padded by ``nhalo`` cells on *both* sides, so the
    padded extent along an axis is ``n + 2*nhalo``.  The ``*_bounds`` methods
    return the array shapes used for the state, flux, metric and
    metric-Jacobian arrays respectively.

    Args:
        nvars: leading extent of the state and flux arrays (presumably the
            number of solution variables -- TODO confirm against callers).
        ndim: extent of the dimension axes of the flux and metric arrays.
        nx, ny, nz: interior (unpadded) grid extents.
        nhalo: number of halo cells added on each side of every axis.
            Defaults to 3.
    """

    def __init__(self, nvars, ndim, nx, ny, nz, nhalo=3):
        self.nvars = nvars
        self.ndim = ndim

        self.nx = nx
        self.ny = ny
        self.nz = nz

        self.nhalo = nhalo
        # Interior extent plus one halo region on each side of the axis.
        self.nx_halo = self.nx + 2*self.nhalo
        self.ny_halo = self.ny + 2*self.nhalo
        self.nz_halo = self.nz + 2*self.nhalo

    def state_bounds(self):
        """Return the shape ``(nvars, nx_halo, ny_halo, nz_halo)``."""
        return self.nvars, self.nx_halo, self.ny_halo, self.nz_halo

    def flux_bounds(self):
        """Return the shape ``(nvars, ndim, nx_halo, ny_halo, nz_halo)``."""
        return self.nvars, self.ndim, self.nx_halo, self.ny_halo, self.nz_halo

    def metric_bounds(self):
        """Return the shape ``(ndim, ndim, nx_halo, ny_halo, nz_halo)``."""
        return self.ndim, self.ndim, self.nx_halo, self.ny_halo, self.nz_halo

    def jacobian_bounds(self):
        """Return the shape ``(nx_halo, ny_halo, nz_halo)``."""
        return self.nx_halo, self.ny_halo, self.nz_halo


class FluxDerivativeArrays:
    """Plain container grouping the four inputs of a flux-derivative test.

    The objects are stored as given, without copying or validation, under
    the attributes ``states``, ``fluxes``, ``metrics`` and
    ``metric_jacobians``.
    """

    def __init__(self, states, fluxes, metrics, metric_jacobians):
        # Store in the same order as the constructor arguments.
        self.states, self.fluxes = states, fluxes
        self.metrics, self.metric_jacobians = metrics, metric_jacobians


def random_array(*shape):
    """Return a Fortran-ordered float32 array of uniform samples in [0, 1).

    Args:
        *shape: extents of the array, passed as separate positional ints.

    Returns:
        A new ``np.float32`` array of the given shape, laid out in Fortran
        (column-major) order.
    """
    # ``astype`` already allocates a fresh array, so request Fortran order
    # directly instead of making a second full copy with ``.copy(order="F")``.
    return np.random.random_sample(shape).astype(np.float32, order="F")


def random_array_on_device(queue, *shape):
    """Create a random float32 array and copy it to an OpenCL device.

    The host array holds uniform samples in [0, 1), has dtype ``np.float32``
    and Fortran (column-major) layout, exactly as ``random_array`` builds it.

    Args:
        queue: command queue passed to ``pyopencl.array.to_device``.
        *shape: extents of the array, passed as separate positional ints.

    Returns:
        The ``pyopencl.array.Array`` returned by ``to_device``.
    """
    # ``import pyopencl as cl`` does not by itself load the ``pyopencl.array``
    # submodule; import it explicitly so ``cl.array`` always resolves instead
    # of relying on some other module having imported it first.
    import pyopencl.array  # noqa: F401

    # ``astype`` already returns a new array; ask for Fortran order directly.
    ary = np.random.random_sample(shape).astype(np.float32, order="F")
    return cl.array.to_device(queue, ary)


def random_flux_derivative_arrays(params):
    """Build a FluxDerivativeArrays of random host arrays sized by ``params``.

    The states, fluxes, metrics and metric_jacobians arrays are generated
    with ``random_array`` in that order, using the shapes reported by
    ``params.state_bounds()``, ``flux_bounds()``, ``metric_bounds()`` and
    ``jacobian_bounds()``.  Keeping the order fixed keeps results
    reproducible under a fixed NumPy seed.
    """
    shapes = (
        params.state_bounds(),
        params.flux_bounds(),
        params.metric_bounds(),
        params.jacobian_bounds(),
    )
    return FluxDerivativeArrays(*(random_array(*dims) for dims in shapes))


def random_flux_derivative_arrays_on_device(ctx_factory, params):
    """Build a FluxDerivativeArrays of random device arrays sized by ``params``.

    A command queue is first obtained from ``fixtures.get_queue(ctx_factory)``.
    The states, fluxes, metrics and metric_jacobians arrays are then created
    with ``random_array_on_device`` on that queue, in that order.
    """
    queue = fixtures.get_queue(ctx_factory)

    shapes = (
        params.state_bounds(),
        params.flux_bounds(),
        params.metric_bounds(),
        params.jacobian_bounds(),
    )
    return FluxDerivativeArrays(
        *(random_array_on_device(queue, *dims) for dims in shapes)
    )
