from __future__ import annotations


__copyright__ = "Copyright (C) 2014 Andreas Kloeckner"

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

from typing import TYPE_CHECKING, Any, Literal, overload

import numpy as np
from numpy.typing import NDArray
from typing_extensions import override

from gmsh_interop.reader import (
    GmshElementBase,
    GmshMeshReceiverBase,
    GmshSimplexElementBase,
    GmshTensorProductElementBase,
)
from gmsh_interop.runner import (
    FileSource,
    LiteralSource,
    ScriptSource,
    ScriptWithFilesSource,
)


if TYPE_CHECKING:
    from collections.abc import Mapping, MutableSequence, Sequence

    from gmsh_interop.runner import GmshSource
    from modepy import Shape

    from meshmode.mesh import Mesh, MeshElementGroup


# Module documentation in reST; the ``autoclass``/``autofunction`` directives
# below are Sphinx directives that pull in the docstrings of the listed names.
# NOTE(review): GmshSource is only imported under TYPE_CHECKING above, so it is
# not actually available from this module at runtime despite "also reexported
# here"; LiteralSource is in __all__ but not listed here. Confirm intent.
__doc__ = """

.. autoclass:: ScriptSource
.. autoclass:: FileSource
.. autoclass:: ScriptWithFilesSource

.. autofunction:: read_gmsh
.. autofunction:: generate_gmsh
.. autofunction:: from_meshpy
.. autofunction:: from_vertices_and_simplices
.. autofunction:: to_json

References
----------

.. class:: GmshSource

    See :mod:`gmsh_interop.runner`, also reexported here.

.. class:: IndexArray

    A :class:`numpy.ndarray` of integer dtype.
"""


# {{{ gmsh receiver

# NOTE: Keep in sync with definitions from gmsh_interop.reader
# Type aliases used in the annotations of this module:
# - Point: coordinates of a single node (floating-point array), as passed to
#   GmshMeshReceiver.add_node.
# - IndexArray: an array of integer indices (vertex/node numbers, element ids).
Point = NDArray[np.floating]
IndexArray = NDArray[np.integer]


class GmshMeshReceiver(GmshMeshReceiverBase):
    """Collects the data delivered through the :mod:`gmsh_interop.reader`
    receiver callbacks and assembles a :class:`meshmode.mesh.Mesh` from it
    (see :meth:`get_mesh`).

    Data is kept in fields similar to those of :class:`meshpy.triangle.MeshInfo`
    and :class:`meshpy.tet.MeshInfo`. Nodes and elements may arrive in any
    order, so they are stored in preallocated lists (with *None* as a sentinel)
    that are filled by index.
    """

    def __init__(self,
                mesh_construction_kwargs: Mapping[str, Any] | None = None) -> None:
        """
        :arg mesh_construction_kwargs: *None* or a mapping of extra keyword
            arguments that :meth:`get_mesh` forwards to
            :func:`meshmode.mesh.make_mesh`.
        """
        if mesh_construction_kwargs is None:
            mesh_construction_kwargs = {}

        # Node coordinates: a list filled by add_node, converted into a
        # float64 array by finalize_nodes.
        self.points: MutableSequence[Point | None] | Point = []

        # Per-element data, indexed by gmsh element number.
        self.element_vertices: MutableSequence[IndexArray | None] = []
        self.element_nodes: MutableSequence[IndexArray | None] = []
        self.element_types: MutableSequence[GmshElementBase | None] = []
        # gmsh tag ids of each element; stays None for untagged elements
        self.element_markers: MutableSequence[Sequence[int] | None] = []

        # (name, dimension) of every distinct tag, indexed by our own tag number
        self.tags: list[tuple[str, int]] = []
        # gmsh tag id -> index into self.tags
        self.gmsh_tag_index_to_mine: dict[int, int] = {}

        # element groups of the mesh most recently built by get_mesh
        self.groups: MutableSequence[MeshElementGroup] | None = None
        self.mesh_construction_kwargs: Mapping[str, Any] = mesh_construction_kwargs

    # {{{ node callbacks

    @override
    def set_up_nodes(self, count: int) -> None:
        # Preallocate the node list, treating None as a sentinel value. This is
        # not done for performance, but so that add_node can assign values at
        # indices in arbitrary order.
        self.points = [None] * count

    @override
    def add_node(self, node_nr: int, point: Point) -> None:
        self.points[node_nr] = point

    @override
    def finalize_nodes(self) -> None:
        self.points = np.array(self.points, dtype=np.float64)

    # }}}

    # {{{ element callbacks

    @override
    def set_up_elements(self, count: int) -> None:
        # Preallocate the per-element lists so that add_element can assign
        # elements in arbitrary order.
        self.element_vertices = [None] * count
        self.element_nodes = [None] * count
        self.element_types = [None] * count
        self.element_markers = [None] * count

    @override
    def add_element(self,
                    element_nr: int,
                    element_type: GmshElementBase,
                    vertex_nrs: IndexArray,
                    lexicographic_nodes: IndexArray,
                    tag_numbers: Sequence[int]) -> None:
        self.element_vertices[element_nr] = vertex_nrs
        self.element_nodes[element_nr] = lexicographic_nodes
        self.element_types[element_nr] = element_type
        # untagged elements keep the None sentinel
        if tag_numbers:
            self.element_markers[element_nr] = tag_numbers

    @override
    def finalize_elements(self) -> None:
        pass

    # }}}

    # {{{ tag callbacks

    @override
    def add_tag(self, name: str, index: int, dimension: int) -> None:
        """Record the tag *name* of dimension *dimension* under gmsh id *index*.

        Adding the same (name, dimension) again under the same id is a no-op.

        :raises ValueError: if *index* was already registered with a different
            name or dimension.
        """
        my_index = self.gmsh_tag_index_to_mine.get(index)
        if my_index is None:
            # new tag
            self.gmsh_tag_index_to_mine[index] = len(self.tags)
            self.tags.append((name, dimension))
            return

        # ensure we are not trying to add different tags with the same index
        if self.tags[my_index] != (name, dimension):
            raise ValueError("Distinct tags with the same tag id")

    @override
    def finalize_tags(self) -> None:
        pass

    # }}}

    # {{{ mesh construction

    def _get_face_vertex_indices_to_tags(
            self, mesh_bulk_dim: int) -> dict[frozenset[int], list[str]]:
        """Map the vertex set of every element to the names of its tags of
        dimension *mesh_bulk_dim* - 1 (possibly an empty list).

        Returns an empty mapping if no tags were received.
        """
        face_vertex_indices_to_tags: dict[frozenset[int], list[str]] = {}
        if not self.tags:
            return face_vertex_indices_to_tags

        for el_vertices, el_markers in zip(
                self.element_vertices, self.element_markers, strict=True):
            assert el_vertices is not None

            el_tag_indices = [] if el_markers is None else [
                    mytag
                    for t in el_markers
                    if (mytag := self.gmsh_tag_index_to_mine.get(t)) is not None]

            # record only tags of boundary dimension
            el_tags = [self.tags[itag][0] for itag in el_tag_indices
                       if self.tags[itag][1] == mesh_bulk_dim - 1]

            face_vertex_indices = frozenset(int(iv) for iv in el_vertices)
            face_vertex_indices_to_tags.setdefault(
                    face_vertex_indices, []).extend(el_tags)

        return face_vertex_indices_to_tags

    @staticmethod
    def _make_element_group(
            group_el_type: GmshElementBase,
            vertex_indices: IndexArray,
            nodes: NDArray[np.floating],
            vertices: NDArray[np.floating]) -> MeshElementGroup:
        """Build a meshmode element group for the gmsh element type
        *group_el_type* from its (gmsh-ordered) *vertex_indices* and *nodes*.

        :raises NotImplementedError: if *group_el_type* is neither a simplex
            nor a tensor-product element type.
        """
        import modepy as mp

        from meshmode.mesh import SimplexElementGroup, TensorProductElementGroup

        if isinstance(group_el_type, GmshSimplexElementBase):
            shape: Shape = mp.Simplex(group_el_type.dimensions)
        elif isinstance(group_el_type, GmshTensorProductElementBase):
            shape = mp.Hypercube(group_el_type.dimensions)
        else:
            raise NotImplementedError(
                    f"gmsh element type: {type(group_el_type).__name__}")

        space = mp.space_for_shape(shape, group_el_type.order)
        unit_nodes = mp.equispaced_nodes_for_space(space, shape)

        if isinstance(group_el_type, GmshSimplexElementBase):
            group: MeshElementGroup = SimplexElementGroup.make_group(
                group_el_type.order,
                vertex_indices,
                nodes,
                unit_nodes=unit_nodes
                )

            if group.dim == 2:
                # Every 2D simplex group is flipped.
                # NOTE(review): presumably gmsh's triangle orientation is the
                # opposite of what meshmode expects -- confirm.
                from meshmode.mesh.processing import flip_element_group
                group = flip_element_group(vertices, group,
                        np.ones(group.nelements, bool))

        elif isinstance(group_el_type, GmshTensorProductElementBase):
            # reorder the vertices from gmsh order into the order given by the
            # lexicographic node indices of the order-1 element of this type
            vertex_shuffle = type(group_el_type)(
                    order=1).get_lexicographic_gmsh_node_indices()

            group = TensorProductElementGroup.make_group(
                group_el_type.order,
                vertex_indices[:, vertex_shuffle],
                nodes,
                unit_nodes=unit_nodes
                )
        else:
            # NOTE: already checked above
            raise AssertionError()

        return group

    @overload
    def get_mesh(
        self,
        return_tag_to_elements_map: Literal[True]
        ) -> tuple[Mesh, dict[str, IndexArray]]: ...

    @overload
    def get_mesh(
        self,
        return_tag_to_elements_map: Literal[False] = False,
        ) -> Mesh: ...

    def get_mesh(self, return_tag_to_elements_map: bool = False
                 ) -> tuple[Mesh, dict[str, IndexArray]] | Mesh:
        """Assemble a :class:`meshmode.mesh.Mesh` from the received data.

        Only elements of the highest dimension present ("bulk" elements)
        become element groups, one group per gmsh element type. Tags of
        dimension bulk - 1, keyed by the vertex set of the element carrying
        them, are used to compute facial adjacency when tags are present and
        the mesh is considered conforming. The resulting groups are also
        stored in :attr:`groups`.

        :arg return_tag_to_elements_map: if *True*, additionally return a
            :class:`dict` mapping each tag name attached to a bulk element to
            a :class:`numpy.ndarray` of mesh-wide indices of those elements.
        :raises RuntimeError: if no elements were received.
        :raises NotImplementedError: for unsupported gmsh element types.
        """
        el_type_hist: dict[GmshElementBase, int] = {}
        for el_type in self.element_types:
            assert el_type is not None
            el_type_hist[el_type] = el_type_hist.get(el_type, 0) + 1

        if not el_type_hist:
            raise RuntimeError("empty mesh in gmsh input")

        points = self.points
        assert isinstance(points, np.ndarray)
        groups = self.groups = []
        ambient_dim = points.shape[-1]

        mesh_bulk_dim = max(el_type.dimensions for el_type in el_type_hist)

        # map set of face vertex indices to list of tags associated to face
        face_vertex_indices_to_tags = self._get_face_vertex_indices_to_tags(
                mesh_bulk_dim)

        # {{{ build vertex array

        vertices = np.asarray(points.T, dtype=np.float64, order="C")

        # }}}

        from meshmode.mesh import make_mesh

        bulk_el_types: set[GmshElementBase] = set()
        group_base_elem_nr = 0
        tag_to_elements: dict[str, list[int]] = {}

        for group_el_type, ngroup_elements in el_type_hist.items():
            if group_el_type.dimensions != mesh_bulk_dim:
                continue

            bulk_el_types.add(group_el_type)

            nodes = np.empty(
                    (ambient_dim, ngroup_elements, group_el_type.node_count()),
                    np.float64)
            vertex_indices = np.empty(
                    (ngroup_elements, group_el_type.vertex_count()),
                    np.int32)

            i = 0
            for el_vertices, el_nodes, el_type, el_markers in zip(
                    self.element_vertices,
                    self.element_nodes,
                    self.element_types,
                    self.element_markers, strict=True):
                # Compare with '!=' (not identity) so that element selection
                # agrees with the dict-based grouping in el_type_hist.
                if el_type != group_el_type:
                    continue

                assert el_vertices is not None
                assert el_nodes is not None

                nodes[:, i] = points[el_nodes].T
                vertex_indices[i] = el_vertices

                if el_markers is not None:
                    for t in el_markers:
                        mytag = self.gmsh_tag_index_to_mine.get(t)
                        if mytag is None:
                            continue

                        tag, _dim = self.tags[mytag]
                        tag_to_elements.setdefault(tag, []).append(
                                group_base_elem_nr + i)

                i += 1

            # nodes/vertex_indices come from np.empty: every row must be filled
            assert i == ngroup_elements

            group = self._make_element_group(
                    group_el_type, vertex_indices, nodes, vertices)
            groups.append(group)

            group_base_elem_nr += group.nelements

        tag_to_elements_ary = {tag: np.array(els, dtype=np.int32)
            for tag, els in tag_to_elements.items()}

        # FIXME: This is heuristic.
        if len(bulk_el_types) == 1:
            is_conforming = True
        else:
            is_conforming = mesh_bulk_dim < 3

        # compute facial adjacency for Mesh if there is tag information
        facial_adjacency_groups = None
        if is_conforming and self.tags:
            from meshmode.mesh import _compute_facial_adjacency_from_vertices
            facial_adjacency_groups = _compute_facial_adjacency_from_vertices(
                    groups, np.dtype(np.int32), np.dtype(np.int8),
                    face_vertex_indices_to_tags)

        mesh = make_mesh(
                vertices, groups,
                is_conforming=is_conforming,
                facial_adjacency_groups=facial_adjacency_groups,
                **self.mesh_construction_kwargs)

        return (mesh, tag_to_elements_ary) if return_tag_to_elements_map else mesh

    # }}}

# }}}


# {{{ gmsh

# Axis labels indexed by coordinate axis number; used to name the axis in the
# "all vertices' ... coordinate is zero" warning of generate_gmsh.
AXIS_NAMES = "xyz"


@overload
def read_gmsh(
            filename: str,
            force_ambient_dim: int | None = None,
            *, mesh_construction_kwargs: Mapping[str, Any] | None = None,
            return_tag_to_elements_map: Literal[False] = False
        ) -> Mesh: ...


@overload
def read_gmsh(
            filename: str,
            force_ambient_dim: int | None = None,
            *, mesh_construction_kwargs: Mapping[str, Any] | None = None,
            return_tag_to_elements_map: Literal[True],
        ) -> tuple[Mesh, dict[str, IndexArray]]: ...


def read_gmsh(
            filename: str,
            force_ambient_dim: int | None = None,
            *, mesh_construction_kwargs: Mapping[str, Any] | None = None,
            return_tag_to_elements_map: bool = False
        ) -> tuple[Mesh, dict[str, IndexArray]] | Mesh:
    """Parse the gmsh mesh file at *filename* into a
    :class:`meshmode.mesh.Mesh`.

    :arg force_ambient_dim: when not *None*, point coordinates are truncated
        to this many dimensions.
    :arg mesh_construction_kwargs: *None* or a mapping of keyword arguments
        forwarded to :func:`meshmode.mesh.make_mesh`.
    :arg return_tag_to_elements_map: when *True*, return the tuple
        ``(mesh, tag_to_elements)`` instead of just the mesh, where
        ``tag_to_elements`` is a :class:`dict` mapping each gmsh tag name
        attached to a volume (bulk) element to a :class:`numpy.ndarray` of the
        mesh-wide indices of those elements.
    """
    # aliased so that it does not shadow this function's own name
    from gmsh_interop.reader import read_gmsh as read_gmsh_into_receiver

    receiver = GmshMeshReceiver(
            mesh_construction_kwargs=mesh_construction_kwargs)
    read_gmsh_into_receiver(
            receiver, filename, force_dimension=force_ambient_dim)

    return receiver.get_mesh(
            return_tag_to_elements_map=return_tag_to_elements_map)


@overload
def generate_gmsh(
            source: GmshSource,
            dimensions: int | None = None,
            order: int | None = None,
            *, other_options: Sequence[str] | None = None,
            extension: str = "geo",
            gmsh_executable: str = "gmsh",
            force_ambient_dim: int | None = None,
            output_file_path: str | None = None,
            mesh_construction_kwargs: Mapping[str, Any] | None = None,
            target_unit: Literal["M", "MM"] | None = None,
            return_tag_to_elements_map: Literal[False] = False,
            output_file_name: str | None = None,
        ) -> Mesh: ...


@overload
def generate_gmsh(
            source: GmshSource,
            dimensions: int | None = None,
            order: int | None = None,
            *, other_options: Sequence[str] | None = None,
            extension: str = "geo",
            gmsh_executable: str = "gmsh",
            force_ambient_dim: int | None = None,
            output_file_path: str | None = None,
            mesh_construction_kwargs: Mapping[str, Any] | None = None,
            target_unit: Literal["M", "MM"] | None = None,
            return_tag_to_elements_map: Literal[True],
            output_file_name: str | None = None,
        ) -> tuple[Mesh, dict[str, IndexArray]]: ...


def generate_gmsh(
            source: GmshSource,
            dimensions: int | None = None,
            order: int | None = None,
            *, other_options: Sequence[str] | None = None,
            extension: str = "geo",
            gmsh_executable: str = "gmsh",
            force_ambient_dim: int | None = None,
            output_file_path: str | None = None,
            mesh_construction_kwargs: Mapping[str, Any] | None = None,
            target_unit: Literal["M", "MM"] | None = None,
            return_tag_to_elements_map: bool = False,
            output_file_name: str | None = None,
        ) -> tuple[Mesh, dict[str, IndexArray]] | Mesh:
    """Run :command:`gmsh` on the input given by *source*, and return a
    :class:`meshmode.mesh.Mesh` based on the result.

    :arg source: a :class:`GmshSource` describing the gmsh input, e.g. a
        :class:`gmsh_interop.runner.FileSource` or
        :class:`gmsh_interop.runner.ScriptSource` (both also reexported here).
    :arg dimensions: forwarded to :class:`gmsh_interop.runner.GmshRunner`.
    :arg order: forwarded to :class:`gmsh_interop.runner.GmshRunner`.
    :arg other_options: additional options forwarded to
        :class:`gmsh_interop.runner.GmshRunner`; *None* means none.
    :arg extension: extension of the gmsh input file, forwarded to the runner.
    :arg gmsh_executable: name or path of the :command:`gmsh` executable.
    :arg force_ambient_dim: if not *None*, truncate point coordinates to
        this many dimensions. If *None*, a warning is issued when all vertices
        have a zero coordinate along some axis.
    :arg output_file_path: if not *None*, path (file name, optionally with a
        directory, which defaults to the current working directory) under
        which the gmsh output file is kept.
    :arg mesh_construction_kwargs: *None* or a dictionary of keyword
        arguments passed to :func:`meshmode.mesh.make_mesh`.
    :arg target_unit: Value of the option *Geometry.OCCTargetUnit*.
        Supported values are the strings `'M'` or `'MM'`. Not specifying it
        is deprecated; it then defaults to `'MM'`.
    :arg return_tag_to_elements_map: If *True*, return in addition to the mesh
        a :class:`dict` mapping tag names to :class:`numpy.ndarray`\\ s of
        mesh-wide element indices, as in :func:`read_gmsh`.
    :arg output_file_name: deprecated alias for *output_file_path*.

    :raises ValueError: if both *output_file_path* and *output_file_name*
        are given.
    """
    from warnings import warn

    # Validate arguments before doing any work: the deprecated alias must not
    # silently override the new parameter.
    if output_file_name is not None and output_file_path is not None:
        raise ValueError(
            "output_file_name and output_file_path cannot both be given; "
            "pass only output_file_path")

    if other_options is None:
        other_options = []

    if target_unit is None:
        target_unit = "MM"
        warn(
                "Not specifying target_unit is deprecated. Set target_unit='MM' "
                "to retain prior behavior.", DeprecationWarning, stacklevel=2)

    if output_file_name is not None:
        warn(
            "output_file_name is deprecated and will be removed in Q1 2026. "
            "Use output_file_path instead.", DeprecationWarning, stacklevel=2)
        output_file_path = output_file_name

    # GmshRunner takes the output location as a (file name, directory) pair.
    if output_file_path is not None:
        import os
        output_file_name = os.path.basename(output_file_path)
        save_output_file_in = os.path.dirname(output_file_path) or os.getcwd()
    else:
        output_file_name = None
        save_output_file_in = None

    recv = GmshMeshReceiver(mesh_construction_kwargs=mesh_construction_kwargs)

    from gmsh_interop.reader import parse_gmsh
    from gmsh_interop.runner import GmshRunner

    with GmshRunner(source, dimensions, order=order,
            other_options=other_options, extension=extension,
            gmsh_executable=gmsh_executable,
            output_file_name=output_file_name,
            target_unit=target_unit,
            save_output_file_in=save_output_file_in) as runner:
        assert runner.output_file
        parse_gmsh(recv, runner.output_file,
                force_dimension=force_ambient_dim)

    result = recv.get_mesh(return_tag_to_elements_map=return_tag_to_elements_map)

    if force_ambient_dim is None:
        mesh = result[0] if return_tag_to_elements_map else result

        # Warn about an axis along which all vertices are zero, which
        # suggests a lower-dimensional mesh embedded in a higher dimension.
        dim = mesh.vertices.shape[0]
        for idim in range(dim):
            if (mesh.vertices[idim] == 0).all():
                warn(f"all vertices' {AXIS_NAMES[idim]} coordinate is zero -- "
                     f"perhaps you want to pass force_ambient_dim={idim} (pass "
                     "any fixed value to force_ambient_dim to silence this warning)",
                     stacklevel=2)
                break

    return result

# }}}


# {{{ meshpy

def from_meshpy(mesh_info, order: int = 1) -> Mesh:
    """Build a :class:`meshmode.mesh.Mesh` from a :mod:`meshpy` *mesh_info*
    object, as produced by either :mod:`meshpy.triangle` or :mod:`meshpy.tet`.

    Only ``mesh_info.points`` and ``mesh_info.elements`` are used; boundary
    and volume markers are currently not transferred.

    :arg order: element order, forwarded to
        :func:`meshmode.mesh.generation.make_group_from_vertices`.
    """
    from meshmode.mesh import make_mesh
    from meshmode.mesh.generation import make_group_from_vertices

    # transpose so that coordinates run along the first axis, i.e. the
    # (ambient_dim, nvertices) layout used by from_vertices_and_simplices
    point_array = np.array(mesh_info.points)
    vertices: NDArray[np.floating] = point_array.T
    elements: NDArray[np.integer] = np.array(mesh_info.elements, np.int32)

    # FIXME: Should transfer boundary/volume markers
    group = make_group_from_vertices(vertices, elements, order)
    return make_mesh(vertices=vertices, groups=[group], is_conforming=True)

# }}}


# {{{ from_vertices_and_simplices

def from_vertices_and_simplices(
                vertices: np.ndarray,
                simplices: np.ndarray,
                order: int = 1,
                fix_orientation: bool = False
            ) -> Mesh:
    """Build a conforming :class:`meshmode.mesh.Mesh` with a single element
    group from explicit vertex coordinates and simplex connectivity.

    :arg vertices: array of vertex coordinates of shape
        *(ambient_dim, nvertices)*.
    :arg simplices: array of shape *(nelements, nvertices_per_element)*
        holding the (mesh-wide) vertex indices of each element.
    :arg order: element order, forwarded to
        :func:`meshmode.mesh.generation.make_group_from_vertices`.
    :arg fix_orientation: if *True*, flip every element whose computed
        orientation is negative. Only supported for volume meshes, i.e. when
        the element dimension equals *ambient_dim*.
    :raises ValueError: if *fix_orientation* is requested and the element
        dimension differs from *ambient_dim*.
    """
    from meshmode.mesh import make_mesh
    from meshmode.mesh.generation import make_group_from_vertices

    group = make_group_from_vertices(vertices, simplices, order)

    if fix_orientation:
        ambient_dim = vertices.shape[0]
        if group.dim != ambient_dim:
            raise ValueError("can only fix orientation of volume meshes")

        from meshmode.mesh.processing import (
            find_volume_mesh_element_group_orientation,
            flip_element_group,
        )
        orientation = find_volume_mesh_element_group_orientation(vertices, group)
        group = flip_element_group(vertices, group, orientation < 0)

    return make_mesh(vertices=vertices, groups=[group], is_conforming=True)

# }}}


# {{{ to_json

def to_json(mesh: Mesh) -> dict[str, Any]:
    """Return a JSON-able Python data structure for *mesh*. The structure directly
    reflects the :class:`meshmode.mesh.Mesh` data structure.

    Arrays are converted to nested lists via :meth:`numpy.ndarray.tolist`.
    Absent data is stored as *None*: ``"vertices"`` if the mesh has no
    vertices, a group's ``"vertex_indices"`` if the group has none, and
    ``"nodal_adjacency"`` if it is unavailable. ``"facial_adjacency_groups"``
    is always *None* (not yet implemented).
    """
    from meshmode import DataUnavailableError

    def group_to_json(group: MeshElementGroup) -> dict[str, Any]:
        # vertex_indices may be None (presumably for vertex-less meshes); emit
        # None instead of failing, matching the handling of mesh.vertices below
        vertex_indices = group.vertex_indices
        return {
            "type": type(group).__name__,
            "order": group.order,
            "vertex_indices": (
                None if vertex_indices is None else vertex_indices.tolist()),
            "nodes": group.nodes.tolist(),
            "unit_nodes": group.unit_nodes.tolist(),
            "dim": group.dim,
            }

    def nodal_adjacency_to_json(mesh: Mesh) -> dict[str, Any] | None:
        try:
            na = mesh.nodal_adjacency
        except DataUnavailableError:
            return None

        return {
            "neighbors_starts": na.neighbors_starts.tolist(),
            "neighbors": na.neighbors.tolist(),
            }

    return {
        # VERSION 0:
        # - initial version
        #
        # VERSION 1:
        # - added is_conforming

        "version": 1,
        "vertices": None if mesh.vertices is None else mesh.vertices.tolist(),
        "groups": [group_to_json(group) for group in mesh.groups],
        "nodal_adjacency": nodal_adjacency_to_json(mesh),
        # not yet implemented
        "facial_adjacency_groups": None,
        "is_conforming": mesh.is_conforming,
        }

# }}}


# Public API of this module. The *Source classes are re-exported from
# gmsh_interop.runner (imported at the top of this file).
__all__ = [
    "FileSource",
    "LiteralSource",
    "ScriptSource",
    "ScriptWithFilesSource",
    "from_meshpy",
    "from_vertices_and_simplices",
    "generate_gmsh",
    "read_gmsh",
    "to_json",
]
# vim: foldmethod=marker
