import numpy as np
import pyopencl as cl
import pyopencl.array  # noqa

import device_fixtures as device


class RoeParams:
    """Plain holder for the ``nvars``, ``ndim`` and ``d`` values of a test case.

    ``d`` is stored exactly as given; ``roe_params`` in this module passes a
    1-based direction index (1, 2 or 3) for it.
    """

    def __init__(self, nvars, ndim, d):
        self.nvars, self.ndim, self.d = nvars, ndim, d

    def mat_bounds(self):
        """Return the ``(nvars, nvars)`` shape tuple."""
        size = self.nvars
        return size, size

    def vec_bounds(self):
        """Return ``nvars`` itself (a bare value, not a one-element tuple)."""
        return self.nvars


class FluxDerivativeParams:
    """Sizes used to shape the flux-derivative test arrays.

    Every spatial extent is padded with ``nhalo`` extra points on each side,
    so the padded extent along an axis is ``n + 2*nhalo``.

    Parameters:
        nvars: size of the leading axis of the state and flux arrays.
        ndim: size of the dimension axes of the flux and metric arrays.
        nx, ny, nz: unpadded extents along the three spatial axes.
        nhalo: number of padding points added on each side of every spatial
            axis.  Defaults to 3, the width that was previously hard-coded.
    """

    def __init__(self, nvars, ndim, nx, ny, nz, nhalo=3):
        self.nvars = nvars
        self.ndim = ndim

        self.nx = nx
        self.ny = ny
        self.nz = nz

        # Padding is applied on both sides of an axis, hence the factor of 2.
        self.nhalo = nhalo
        self.nx_halo = self.nx + 2*self.nhalo
        self.ny_halo = self.ny + 2*self.nhalo
        self.nz_halo = self.nz + 2*self.nhalo

    def state_bounds(self):
        """Return ``(nvars, nx_halo, ny_halo, nz_halo)``."""
        return self.nvars, self.nx_halo, self.ny_halo, self.nz_halo

    def flux_bounds(self):
        """Return ``(nvars, ndim, nx_halo, ny_halo, nz_halo)``."""
        return self.nvars, self.ndim, self.nx_halo, self.ny_halo, self.nz_halo

    def metric_bounds(self):
        """Return ``(ndim, ndim, nx_halo, ny_halo, nz_halo)``."""
        return self.ndim, self.ndim, self.nx_halo, self.ny_halo, self.nz_halo

    def jacobian_bounds(self):
        """Return ``(nx_halo, ny_halo, nz_halo)``."""
        return self.nx_halo, self.ny_halo, self.nz_halo


class FluxDerivativeArrays:
    """Bundle of the states, fluxes, metrics and metric-Jacobian arrays.

    The constructor stores the four objects exactly as passed in: nothing is
    copied, converted or shape-checked.
    """

    def __init__(self, states, fluxes, metrics, metric_jacobians):
        self.states, self.fluxes = states, fluxes
        self.metrics, self.metric_jacobians = metrics, metric_jacobians


def roe_params(nvars, ndim, direction):
    """Build a ``RoeParams`` from a direction name.

    ``direction`` must be ``"x"``, ``"y"`` or ``"z"``; it is translated to the
    1-based index 1, 2 or 3 respectively.  Any other key raises ``KeyError``.
    """
    axis_number = {name: number for number, name in enumerate("xyz", start=1)}
    return RoeParams(nvars, ndim, axis_number[direction])


def flux_derivative_params(nvars, ndim, n):
    """Build a ``FluxDerivativeParams`` using ``n`` for all three extents."""
    return FluxDerivativeParams(nvars, ndim, nx=n, ny=n, nz=n)


def empty_array_on_device(queue, shape):
    """Create a float32, Fortran-ordered array of ``shape`` via ``cl.array.empty``.

    ``queue`` is forwarded unchanged as the first argument of
    ``cl.array.empty``.
    """
    layout = {"dtype": np.float32, "order": "F"}
    return cl.array.empty(queue, shape, **layout)


def identity(n):
    """Return the ``n`` x ``n`` identity matrix as float32 in Fortran order."""
    # np.identity is a thin wrapper around np.eye; asking eye for the dtype
    # and layout directly yields the same array without intermediate copies.
    return np.eye(n, dtype=np.float32, order="F")


def random_array(*shape):
    """Return a float32, Fortran-ordered array of ``np.random`` samples.

    The extents are given as separate positional arguments, e.g.
    ``random_array(2, 3)`` for a 2 x 3 array.  Values are drawn with
    ``np.random.random_sample`` and then cast down to float32.
    """
    samples = np.random.random_sample(shape)
    # A single astype call both converts the dtype and lays the copy out in
    # Fortran order.
    return samples.astype(np.float32, order="F")


def random_array_on_device(queue, *shape):
    """Generate a ``random_array`` of ``shape`` and copy it with ``cl.array.to_device``.

    ``queue`` is forwarded unchanged to ``cl.array.to_device``; the extents
    are given as separate positional arguments.
    """
    host_values = random_array(*shape)
    return cl.array.to_device(queue, host_values)


def random_flux_derivative_arrays(params):
    """Return a ``FluxDerivativeArrays`` of random host arrays.

    The array shapes come from the ``*_bounds()`` methods of ``params``.
    Arrays are drawn in the order states, fluxes, metrics, metric Jacobians.
    """
    return FluxDerivativeArrays(
        states=random_array(*params.state_bounds()),
        fluxes=random_array(*params.flux_bounds()),
        metrics=random_array(*params.metric_bounds()),
        metric_jacobians=random_array(*params.jacobian_bounds()),
    )


def random_flux_derivative_arrays_on_device(ctx_factory, params):
    """Return a ``FluxDerivativeArrays`` of random arrays placed on the device.

    A queue is obtained from ``device.get_queue(ctx_factory)`` and used for
    every transfer.  The array shapes come from the ``*_bounds()`` methods of
    ``params``, and arrays are created in the order states, fluxes, metrics,
    metric Jacobians.
    """
    queue = device.get_queue(ctx_factory)

    def upload(bounds):
        # One random array of the given shape, transferred through ``queue``.
        return random_array_on_device(queue, *bounds)

    return FluxDerivativeArrays(
        upload(params.state_bounds()),
        upload(params.flux_bounds()),
        upload(params.metric_bounds()),
        upload(params.jacobian_bounds()),
    )


def arrays_from_string(string_arrays):
    """Parse several ``:``-separated array descriptions into a list of arrays.

    Each piece between colons is handed to ``array_from_string``.
    """
    return split_map_to_list(
        string_arrays, map_func=array_from_string, splitter=":")


def array_from_string(string_array):
    """Parse one array description and return a Fortran-ordered copy.

    The separators present decide which parser is used:

    * a ``;`` anywhere selects ``array_from_string_3d``;
    * otherwise a ``,`` anywhere selects ``array_from_string_2d``;
    * otherwise ``array_from_string_1d`` is used.
    """
    if ";" in string_array:
        parse = array_from_string_3d
    elif "," in string_array:
        parse = array_from_string_2d
    else:
        parse = array_from_string_1d
    return parse(string_array).copy(order="F")


def array_from_string_3d(string_array):
    """Parse a ``;``-separated description into an array.

    A leading ``;`` is shorthand: the remainder is parsed as a 1D array and
    reshaped to ``(-1, 1, 1)``.  Otherwise the string is split on ``;`` and
    each piece is parsed with ``array_from_string_2d`` before stacking the
    results with ``np.array``.  An empty string raises ``IndexError``.
    """
    if string_array[0] != ";":
        slabs = split_map_to_list(string_array, array_from_string_2d, ";")
        return np.array(slabs)

    values = array_from_string_1d(string_array[1:])
    return values.reshape((-1, 1, 1))


def array_from_string_2d(string_array):
    """Parse a ``,``-separated description into an array.

    A leading ``,`` is shorthand: the remainder is parsed as a 1D array and
    reshaped to ``(-1, 1)``.  Otherwise the string is split on ``,`` and each
    piece is parsed with ``array_from_string_1d`` before stacking the results
    with ``np.array``.  An empty string raises ``IndexError``.
    """
    if string_array[0] != ",":
        rows = split_map_to_list(string_array, array_from_string_1d, ",")
        return np.array(rows)

    values = array_from_string_1d(string_array[1:])
    return values.reshape((-1, 1))


def array_from_string_1d(string_array):
    """Parse space-separated numbers into a 1D array.

    A leading ``i`` marks integer data: the rest is parsed with ``int`` and
    the dtype is left to NumPy's default.  Without it, the values are parsed
    with ``float`` and stored as float32.  Tokens are split on single spaces.
    An empty string raises ``IndexError``.
    """
    if string_array[0] != "i":
        reals = split_map_to_list(string_array, float, " ")
        return np.array(reals, dtype=np.float32)

    integers = split_map_to_list(string_array[1:], int, " ")
    return np.array(integers)


def split_map_to_list(string, map_func, splitter):
    """Split ``string`` on ``splitter`` and apply ``map_func`` to each piece.

    Returns the results as a list, in the order the pieces appear.
    """
    return [map_func(piece) for piece in string.split(splitter)]
