import numpy as np

import fixtures


class FluxDerivativeParams:
    """Problem-size parameters for a flux-derivative test case.

    Attributes:
        nvars: Number of state variables.
        ndim: Number of spatial dimensions.
        nx, ny, nz: Interior grid extents along each spatial axis.
        nhalo: Number of extra cells added to each of ``nx``, ``ny`` and
            ``nz`` when test arrays are allocated. Whether these cells are
            split evenly between both sides of an axis is not established
            here (presumably 3 per side for the default of 6 -- confirm
            against the kernel).
    """

    def __init__(self, nvars, ndim, nx, ny, nz, nhalo=6):
        self.nvars = nvars
        self.ndim = ndim
        self.nx = nx
        self.ny = ny
        self.nz = nz
        # Previously hard-coded to 6. It is now a parameter with the same
        # default, so existing callers are unaffected.
        self.nhalo = nhalo


class FluxDerivativeArrays:
    """Container grouping the four arrays used by a flux-derivative test.

    The arrays are stored as given, with no copying, conversion or shape
    checking.

    Attributes:
        states: Per-variable state values.
        fluxes: Per-variable, per-dimension flux values.
        metrics: Per-point metric terms.
        metric_jacobians: Per-point metric Jacobian values.
    """

    def __init__(self, states, fluxes, metrics, metric_jacobians):
        self.states, self.fluxes = states, fluxes
        self.metrics, self.metric_jacobians = metrics, metric_jacobians


def random_array(*shape):
    """Return a Fortran-ordered float32 array of uniform samples in [0, 1).

    Args:
        *shape: Array extents, passed as separate positional integers.

    Returns:
        A new ``np.float32`` ndarray of the given shape in column-major
        (Fortran) memory order.
    """
    # Samples are drawn as float64 and then cast. Requesting order="F" in
    # astype performs the cast and the layout change in one copy, instead of
    # a C-ordered cast followed by a second full .copy(order="F").
    return np.random.random_sample(shape).astype(np.float32, order="F")


def random_flux_derivative_arrays(params):
    """Build host-side random inputs for a flux-derivative test.

    Each spatial axis is padded by ``params.nhalo`` cells. With
    ``(X, Y, Z) = (nx + nhalo, ny + nhalo, nz + nhalo)``, the shapes are:

    * states:           (nvars, X, Y, Z)
    * fluxes:           (nvars, ndim, X, Y, Z)
    * metrics:          (ndim, ndim, X, Y, Z)
    * metric_jacobians: (X, Y, Z)

    Args:
        params: Object exposing ``nvars``, ``ndim``, ``nx``, ``ny``, ``nz``
            and ``nhalo``.

    Returns:
        A ``FluxDerivativeArrays`` filled via ``random_array``.
    """
    halo = params.nhalo
    grid = (params.nx + halo, params.ny + halo, params.nz + halo)
    nvars, ndim = params.nvars, params.ndim

    # Keyword arguments are evaluated left to right, so the random draws
    # happen in the same order as before (states, fluxes, metrics, jacobians).
    return FluxDerivativeArrays(
        states=random_array(nvars, *grid),
        fluxes=random_array(nvars, ndim, *grid),
        metrics=random_array(ndim, ndim, *grid),
        metric_jacobians=random_array(*grid),
    )


def random_flux_derivative_arrays_on_device(ctx_factory, params):
    """Allocate device-side arrays for a flux-derivative test.

    Uses the same shapes as the host-side builder: every spatial axis is
    padded by ``params.nhalo`` cells. The allocations come from
    ``fixtures.f_array``. That helper presumably returns a Fortran-ordered
    device array bound to ``queue``, but its behavior (including whether
    contents are random or uninitialized) is defined outside this file --
    confirm.

    Args:
        ctx_factory: Passed through to ``fixtures.get_queue``.
        params: Object exposing ``nvars``, ``ndim``, ``nx``, ``ny``, ``nz``
            and ``nhalo``.

    Returns:
        A ``FluxDerivativeArrays`` holding the four allocated arrays.
    """
    # Obtain the queue before any allocation, as the original did.
    queue = fixtures.get_queue(ctx_factory)

    halo = params.nhalo
    grid = (params.nx + halo, params.ny + halo, params.nz + halo)
    nvars, ndim = params.nvars, params.ndim

    return FluxDerivativeArrays(
        states=fixtures.f_array(queue, nvars, *grid),
        fluxes=fixtures.f_array(queue, nvars, ndim, *grid),
        metrics=fixtures.f_array(queue, ndim, ndim, *grid),
        metric_jacobians=fixtures.f_array(queue, *grid),
    )
