from __future__ import annotations


__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from typing import TYPE_CHECKING, Any, ClassVar, Concatenate
from warnings import warn

from typing_extensions import deprecated, override

from pytools import ndindex

import pymbolic.primitives as p
from pymbolic.mapper import CachedMapper, Mapper, P


if TYPE_CHECKING:
    from collections.abc import Iterator, Sequence

    import numpy as np
    from numpy.typing import NDArray

    from pymbolic.geometric_algebra import MultiVector
    from pymbolic.typing import ArithmeticExpression, Expression


__doc__ = """
.. _prec-constants:

Precedence constants
********************

.. data:: PREC_CALL
.. data:: PREC_POWER
.. data:: PREC_UNARY
.. data:: PREC_PRODUCT
.. data:: PREC_SUM
.. data:: PREC_SHIFT
.. data:: PREC_BITWISE_AND
.. data:: PREC_BITWISE_XOR
.. data:: PREC_BITWISE_OR
.. data:: PREC_COMPARISON
.. data:: PREC_LOGICAL_AND
.. data:: PREC_LOGICAL_OR
.. data:: PREC_IF
.. data:: PREC_NONE

Mappers
*******

.. autoclass:: StringifyMapper
    :show-inheritance:
.. autoclass:: SortingStringifyMapper
    :show-inheritance:
.. autoclass:: SimplifyingSortingStringifyMapper
    :show-inheritance:
.. autoclass:: CSESplittingStringifyMapperMixin
.. autoclass:: LaTeXMapper
    :show-inheritance:
"""


# Operator precedence levels, passed to the stringifiers as *enclosing_prec*.
# A larger value binds more tightly: a subexpression is parenthesized when the
# precedence of its enclosing context exceeds its own (see
# StringifyMapper.parenthesize_if_needed).
PREC_CALL = 15  # function calls, subscripts, attribute lookups
PREC_POWER = 14  # ``**``
PREC_UNARY = 13  # unary operators such as ``~``
PREC_PRODUCT = 12  # ``*``, ``/``, ``//``, ``%``
PREC_SUM = 11  # ``+``
PREC_SHIFT = 10  # ``<<``, ``>>``
PREC_BITWISE_AND = 9  # ``&``
PREC_BITWISE_XOR = 8  # ``^``
PREC_BITWISE_OR = 7  # ``|``
PREC_COMPARISON = 6  # ``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=``
PREC_LOGICAL_AND = 5  # ``and``
PREC_LOGICAL_OR = 4  # ``or``
PREC_IF = 3  # ``x if c else y``
PREC_NONE = 0  # delimited context (top level, call arguments, indices):
#                never triggers parenthesization


# {{{ stringifier


class StringifyMapper(Mapper[str, Concatenate[int, P]]):
    """A mapper to turn an expression tree into a string.

    :class:`pymbolic.ExpressionNode.__str__` is often implemented using
    this mapper.

    When it encounters an unsupported :class:`pymbolic.ExpressionNode`
    subclass, it calls its :meth:`pymbolic.ExpressionNode.make_stringifier`
    method to get a :class:`StringifyMapper` that potentially does.

    Every mapping method receives *enclosing_prec*, the precedence (one of
    the :ref:`prec-constants`) of the context its result is placed in, and
    parenthesizes its output where that context binds more tightly than the
    expression itself.

    .. automethod:: __call__
    """

    # {{{ replaceable string composition interface

    def format(self, s: str, *args: object) -> str:
        """Fill the ``%``-style template *s* with *args*.

        Subclasses may override this to control how output strings are
        assembled.
        """
        return s % args

    def join(self, joiner: str, seq: Sequence[Expression]) -> str:
        """Join the elements of *seq* with *joiner*, each inserted via
        :meth:`format` as ``%s``."""
        return self.format(joiner.join("%s" for _ in seq), *seq)

    # {{{ deprecated junk

    @deprecated("interface not type-safe, use rec_with_parens_around_types")
    def rec_with_force_parens_around(
                self, expr, *args: P.args, **kwargs: P.kwargs) -> str:
        """Deprecated; use :meth:`rec_with_parens_around_types`.

        Recurse into *expr* and parenthesize the result if *expr* is an
        instance of the types given in the ``force_parens_around`` keyword.
        """
        warn(
            "rec_with_force_parens_around is deprecated and will be removed in 2025. "
            "Use rec_with_parens_around_types instead. ",
            DeprecationWarning,
            stacklevel=2,
        )
        # Not currently possible to make this type-safe:
        # https://peps.python.org/pep-0612/#concatenating-keyword-parameters

        force_parens_around = kwargs.pop("force_parens_around", ())

        result = self.rec(expr, *args, **kwargs)

        if isinstance(expr, force_parens_around):
            result = f"({result})"

        return result

    def join_rec(
        self,
        joiner: str,
        seq: Sequence[Expression],
        prec: int,
        *args: P.args,
        **kwargs: P.kwargs,  # force_with_parens_around may hide in here
    ) -> str:
        """Render every element of *seq* at precedence *prec* and join the
        results with *joiner*.

        A ``force_parens_around`` keyword argument is still accepted for
        backward compatibility (deprecated); it forwards to
        :meth:`join_rec_with_parens_around_types`.
        """
        f = joiner.join("%s" for _ in seq)

        if "force_parens_around" in kwargs:
            warn(
                "Passing force_parens_around join_rec is deprecated and will be "
                "removed in 2025. "
                "Use join_rec_with_parens_around_types instead. ",
                DeprecationWarning,
                stacklevel=2,
            )
            # Not currently possible to make this type-safe:
            # https://peps.python.org/pep-0612/#concatenating-keyword-parameters
            parens_around_types: tuple[type, ...] = kwargs.pop("force_parens_around")
            return self.join_rec_with_parens_around_types(
                joiner, seq, prec, parens_around_types, *args, **kwargs
            )

        return self.format(
            f,
            *[self.rec(i, prec, *args, **kwargs) for i in seq],
        )

    # }}}

    def rec_with_parens_around_types(
        self,
        expr: Expression,
        enclosing_prec: int,
        parens_around: tuple[type, ...],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        """Render *expr* at *enclosing_prec*, additionally parenthesizing the
        result if *expr* is an instance of one of *parens_around*."""
        result = self.rec(expr, enclosing_prec, *args, **kwargs)

        if isinstance(expr, parens_around):
            result = f"({result})"

        return result

    def join_rec_with_parens_around_types(
        self,
        joiner: str,
        seq: Sequence[Expression],
        prec: int,
        parens_around_types: tuple[type, ...],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        """Like :meth:`join_rec`, but elements that are instances of
        *parens_around_types* are always parenthesized."""
        f = joiner.join("%s" for _ in seq)
        return self.format(
            f,
            *[
                self.rec_with_parens_around_types(
                    i, prec, parens_around_types, *args, **kwargs
                )
                for i in seq
            ],
        )

    def parenthesize(self, s: str) -> str:
        """Wrap *s* in parentheses unconditionally."""
        return f"({s})"

    def parenthesize_if_needed(self, s: str, enclosing_prec: int, my_prec: int) -> str:
        """Parenthesize *s* if its context (*enclosing_prec*) binds more
        tightly than the operator that produced it (*my_prec*)."""
        if enclosing_prec > my_prec:
            return f"({s})"
        else:
            return s

    # }}}

    # {{{ mappings

    @override
    def handle_unsupported_expression(
        self, expr: p.ExpressionNode,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        """Delegate *expr* to the stringifier returned by its
        ``make_stringifier``.

        :raises ValueError: if *self* is already an instance of that
            stringifier's class, so delegating would make no progress.
        """
        strifier = expr.make_stringifier(self)
        if isinstance(self, type(strifier)):
            raise ValueError(f"stringifier '{self}' cannot handle '{type(expr)}'")

        return strifier(expr, enclosing_prec, *args, **kwargs)

    @override
    def map_constant(
        self, expr: object, /, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        """Render a constant with :func:`str`.

        If the text contains a ``+`` or ``-`` sign (e.g. a negative number)
        and is not already parenthesized, it is parenthesized inside
        operators that bind more tightly than ``+``.
        """
        result = str(expr)

        if (
            not (result.startswith("(") and result.endswith(")"))
            and ("-" in result or "+" in result)
            and (enclosing_prec > PREC_SUM)
        ):
            return self.parenthesize(result)
        else:
            return result

    @override
    def map_variable(self, expr: p.Variable, /,
                     enclosing_prec: int, *args: P.args, **kwargs: P.kwargs) -> str:
        return expr.name

    @override
    def map_wildcard(self, expr: p.Wildcard, /,
                     enclosing_prec: int, *args: P.args, **kwargs: P.kwargs) -> str:
        return "*"

    @override
    def map_function_symbol(
        self,
        expr: p.FunctionSymbol, /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        return expr.__class__.__name__

    @override
    def map_call(
        self, expr: p.Call, /, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        return self.format(
            "%s(%s)",
            self.rec(expr.function, PREC_CALL, *args, **kwargs),
            self.join_rec(", ", expr.parameters, PREC_NONE, *args, **kwargs),
        )

    @override
    def map_call_with_kwargs(
        self,
        expr: p.CallWithKwargs, /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        # Positional arguments first, then ``name=value`` keyword arguments.
        args_strings = tuple([
            self.rec(ch, PREC_NONE, *args, **kwargs) for ch in expr.parameters
        ]) + tuple([
            f"{name}={self.rec(ch, PREC_NONE, *args, **kwargs)}"
            for name, ch in expr.kw_parameters.items()
        ])
        return self.format(
            "%s(%s)",
            self.rec(expr.function, PREC_CALL, *args, **kwargs),
            ", ".join(args_strings),
        )

    @override
    def map_subscript(
        self, expr: p.Subscript, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        # A tuple index is rendered as a bare comma-separated list: a[i, j].
        if isinstance(expr.index, tuple):
            index_str = self.join_rec(", ", expr.index, PREC_NONE, *args, **kwargs)
        else:
            index_str = self.rec(expr.index, PREC_NONE, *args, **kwargs)

        return self.parenthesize_if_needed(
            self.format(
                "%s[%s]",
                self.rec(expr.aggregate, PREC_CALL, *args, **kwargs),
                index_str,
            ),
            enclosing_prec,
            PREC_CALL,
        )

    @override
    def map_lookup(
        self, expr: p.Lookup, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        return self.parenthesize_if_needed(
            self.format(
                "%s.%s", self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), expr.name
            ),
            enclosing_prec,
            PREC_CALL,
        )

    @override
    def map_sum(
        self, expr: p.Sum, /, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        return self.parenthesize_if_needed(
            self.join_rec(" + ", expr.children, PREC_SUM, *args, **kwargs),
            enclosing_prec,
            PREC_SUM,
        )

    # {{{ multiplicative operators

    # Operators sharing the precedence of '*'. They are left-associative and
    # not all mutually associative, so operands of these types get explicit
    # parentheses.
    multiplicative_primitives: tuple[type[p.ExpressionNode], ...] = (
        p.Product, p.Quotient, p.FloorDiv, p.Remainder)

    @override
    def map_product(
        self, expr: p.Product, /, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        # Parenthesize division-like factors: "a*(b // c)" differs from
        # "a*b // c", which means "(a*b) // c".
        return self.parenthesize_if_needed(
            self.join_rec_with_parens_around_types(
                "*",
                expr.children,
                PREC_PRODUCT,
                (p.Quotient, p.FloorDiv, p.Remainder),
                *args,
                **kwargs,
            ),
            enclosing_prec,
            PREC_PRODUCT,
        )

    @override
    def map_quotient(
        self, expr: p.Quotient, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        return self.parenthesize_if_needed(
            self.format(
                "%s / %s",
                # space is necessary--otherwise '/*' becomes
                # start-of-comment in C. ('*' from dereference)
                self.rec_with_parens_around_types(
                    expr.numerator,
                    PREC_PRODUCT,
                    self.multiplicative_primitives,
                    *args,
                    **kwargs,
                ),
                self.rec_with_parens_around_types(
                    expr.denominator,
                    PREC_PRODUCT,
                    self.multiplicative_primitives,
                    *args,
                    **kwargs,
                ),
            ),
            enclosing_prec,
            PREC_PRODUCT,
        )

    @override
    def map_floor_div(
        self, expr: p.FloorDiv, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        return self.parenthesize_if_needed(
            self.format(
                "%s // %s",
                self.rec_with_parens_around_types(
                    expr.numerator,
                    PREC_PRODUCT,
                    self.multiplicative_primitives,
                    *args,
                    **kwargs,
                ),
                self.rec_with_parens_around_types(
                    expr.denominator,
                    PREC_PRODUCT,
                    self.multiplicative_primitives,
                    *args,
                    **kwargs,
                ),
            ),
            enclosing_prec,
            PREC_PRODUCT,
        )

    @override
    def map_remainder(
        self, expr: p.Remainder, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        return self.parenthesize_if_needed(
            self.format(
                "%s %% %s",
                self.rec_with_parens_around_types(
                    expr.numerator,
                    PREC_PRODUCT,
                    self.multiplicative_primitives,
                    *args,
                    **kwargs,
                ),
                self.rec_with_parens_around_types(
                    expr.denominator,
                    PREC_PRODUCT,
                    self.multiplicative_primitives,
                    *args,
                    **kwargs,
                ),
            ),
            enclosing_prec,
            PREC_PRODUCT,
        )

    # }}}

    @override
    def map_power(
        self, expr: p.Power, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        """Render ``base**exponent``.

        ``**`` is right-associative (``a**b**c`` means ``a**(b**c)``), so the
        base is rendered one level tighter than :data:`PREC_POWER`. A power
        appearing as the base is therefore parenthesized: ``(a**b)**c``.
        """
        return self.parenthesize_if_needed(
            self.format(
                "%s**%s",
                self.rec(expr.base, PREC_POWER + 1, *args, **kwargs),
                self.rec(expr.exponent, PREC_POWER, *args, **kwargs),
            ),
            enclosing_prec,
            PREC_POWER,
        )

    @override
    def map_left_shift(
        self, expr: p.LeftShift, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        return self.parenthesize_if_needed(
            # +1 to address
            # https://gitlab.tiker.net/inducer/pymbolic/issues/6
            self.format(
                "%s << %s",
                self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs),
                self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs),
            ),
            enclosing_prec,
            PREC_SHIFT,
        )

    @override
    def map_right_shift(
        self, expr: p.RightShift, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs,
    ) -> str:
        return self.parenthesize_if_needed(
            # +1 to address
            # https://gitlab.tiker.net/inducer/pymbolic/issues/6
            self.format(
                "%s >> %s",
                self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs),
                self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs),
            ),
            enclosing_prec,
            PREC_SHIFT,
        )

    @override
    def map_bitwise_not(
        self,
        expr: p.BitwiseNot, /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        return self.parenthesize_if_needed(
            "~" + self.rec(expr.child, PREC_UNARY, *args, **kwargs),
            enclosing_prec,
            PREC_UNARY,
        )

    @override
    def map_bitwise_or(
        self, expr: p.BitwiseOr, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        return self.parenthesize_if_needed(
            self.join_rec(" | ", expr.children, PREC_BITWISE_OR, *args, **kwargs),
            enclosing_prec,
            PREC_BITWISE_OR,
        )

    @override
    def map_bitwise_xor(
        self,
        expr: p.BitwiseXor, /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        return self.parenthesize_if_needed(
            self.join_rec(" ^ ", expr.children, PREC_BITWISE_XOR, *args, **kwargs),
            enclosing_prec,
            PREC_BITWISE_XOR,
        )

    @override
    def map_bitwise_and(
        self,
        expr: p.BitwiseAnd, /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        return self.parenthesize_if_needed(
            self.join_rec(" & ", expr.children, PREC_BITWISE_AND, *args, **kwargs),
            enclosing_prec,
            PREC_BITWISE_AND,
        )

    @override
    def map_comparison(
        self, expr: p.Comparison, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        """Render ``left <op> right``.

        Comparisons chain in Python (``a < b == c`` means
        ``a < b and b == c``), so both operands are rendered one level
        tighter than :data:`PREC_COMPARISON`. That way a nested comparison
        operand is parenthesized instead of silently turning into a chain.
        """
        return self.parenthesize_if_needed(
            self.format(
                "%s %s %s",
                self.rec(expr.left, PREC_COMPARISON + 1, *args, **kwargs),
                expr.operator,
                self.rec(expr.right, PREC_COMPARISON + 1, *args, **kwargs),
            ),
            enclosing_prec,
            PREC_COMPARISON,
        )

    @override
    def map_logical_not(
        self,
        expr: p.LogicalNot, /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        """Render ``not child``.

        Python's ``not`` binds more loosely than comparisons and all
        arithmetic/bitwise operators, but more tightly than ``and``. The
        result is therefore parenthesized inside anything tighter than
        :data:`PREC_LOGICAL_AND`. Without this, ``not a < b`` would parse as
        ``not (a < b)`` and ``a*not b`` would be a syntax error. The operand
        is still rendered at :data:`PREC_UNARY`, which conservatively
        parenthesizes compound operands.
        """
        return self.parenthesize_if_needed(
            "not " + self.rec(expr.child, PREC_UNARY, *args, **kwargs),
            enclosing_prec,
            PREC_LOGICAL_AND,
        )

    @override
    def map_logical_or(
        self, expr: p.LogicalOr, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        return self.parenthesize_if_needed(
            self.join_rec(" or ", expr.children, PREC_LOGICAL_OR, *args, **kwargs),
            enclosing_prec,
            PREC_LOGICAL_OR,
        )

    @override
    def map_logical_and(
        self,
        expr: p.LogicalAnd, /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        return self.parenthesize_if_needed(
            self.join_rec(" and ", expr.children, PREC_LOGICAL_AND, *args, **kwargs),
            enclosing_prec,
            PREC_LOGICAL_AND,
        )

    @override
    def map_list(
        self,
        expr: list[Expression], /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        return self.format(
            "[%s]", self.join_rec(", ", expr, PREC_NONE, *args, **kwargs)
        )

    @override
    def map_tuple(
        self,
        expr: tuple[Expression, ...], /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        el_str = ", ".join(
            self.rec(child, PREC_NONE, *args, **kwargs) for child in expr
        )
        # The trailing comma distinguishes a 1-tuple "(x,)" from "(x)".
        if len(expr) == 1:
            el_str += ","

        return f"({el_str})"

    @override
    def map_numpy_array(
        self,
        expr: NDArray[np.generic], /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        """Render an object array.

        A 1D array whose entries all render shorter than 15 characters goes
        on one line. Otherwise there is one ``index: value`` line per entry,
        with a dashed separator between entries if any entry exceeds 70
        characters.
        """
        import numpy

        str_array = numpy.zeros(expr.shape, dtype="object")
        max_length = 0
        for i in ndindex(expr.shape):
            s = self.rec(expr[i], PREC_NONE, *args, **kwargs)
            max_length = max(len(s), max_length)
            # Indent continuation lines of multi-line entries.
            str_array[i] = s.replace("\n", "\n  ")

        if len(expr.shape) == 1 and max_length < 15:
            return "array({})".format(", ".join(str_array))
        else:
            lines = [
                "  {}: {}\n".format(",".join(str(i_i) for i_i in i), val)
                for i, val in numpy.ndenumerate(str_array)
            ]
            splitter = ("  " + "-" * 75 + "\n") if max_length > 70 else ""

            return f"array(\n{splitter.join(lines)})"

    @override
    def map_multivector(
        self,
        expr: MultiVector[Any], /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        return expr.stringify(self.rec, enclosing_prec, *args, **kwargs)

    def map_common_subexpression(
        self,
        expr: p.CommonSubexpression, /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        # Plain CSEs print as "CSE(...)"; subclasses print their class name.
        type_name = (
                "CSE" if type(expr) is p.CommonSubexpression
                else type(expr).__name__)

        return self.format(
            "%s(%s)", type_name, self.rec(expr.child, PREC_NONE, *args, **kwargs)
        )

    @override
    def map_if(
        self, expr: p.If, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        then_ = self.rec(expr.then, PREC_LOGICAL_OR, *args, **kwargs)
        cond_ = self.rec(expr.condition, PREC_LOGICAL_OR, *args, **kwargs)
        else_ = self.rec(expr.else_, PREC_LOGICAL_OR, *args, **kwargs)

        return self.parenthesize_if_needed(
            f"{then_} if {cond_} else {else_}",
            enclosing_prec,
            PREC_IF,
        )

    @override
    def map_min(
        self, expr: p.Min, /, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        what = type(expr).__name__.lower()
        return self.format(
            "%s(%s)",
            what,
            self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs),
        )

    @override
    def map_max(
        self, expr: p.Max, /, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        what = type(expr).__name__.lower()
        return self.format(
            "%s(%s)",
            what,
            self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs),
        )

    @override
    def map_derivative(
        self, expr: p.Derivative, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        ddv = " ".join(f"d/d{v}" for v in expr.variables)
        child = self.rec(expr.child, PREC_PRODUCT, *args, **kwargs)
        return f"{ddv} {child}"

    @override
    def map_substitution(
        self,
        expr: p.Substitution, /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        """Render ``[child]{name1=value1, name2=value2, ...}``.

        The braces delimit the substitution list. Without them its commas
        would run together with those of an enclosing call or tuple.
        """
        substs = ", ".join(
            f"{name}={self.rec(val, PREC_NONE, *args, **kwargs)}"
            for name, val in zip(expr.variables, expr.values, strict=True)
        )
        child = self.rec(expr.child, PREC_NONE, *args, **kwargs)

        return f"[{child}]{{{substs}}}"

    @override
    def map_slice(
        self, expr: p.Slice, /, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        # Omitted slice components (None) render as empty strings: "a:", ":b".
        children: list[str] = []
        for child in expr.children:
            if child is None:
                children.append("")
            else:
                children.append(self.rec(child, PREC_NONE, *args, **kwargs))

        return self.parenthesize_if_needed(
            ":".join(children), enclosing_prec, PREC_NONE
        )

    @override
    def map_nan(
        self, expr: p.NaN, /, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        return "NaN"

    # }}}

    @override
    def __call__(self, expr: Expression, /, prec: int = PREC_NONE,
                 *args: P.args, **kwargs: P.kwargs) -> str:
        """Return a string corresponding to *expr*.

        *prec* is the precedence of the context the result will be placed in
        (see :ref:`prec-constants`). If it binds more tightly than *expr*'s
        top-level operator, the result is parenthesized.
        """

        return Mapper.__call__(self, expr, prec, *args, **kwargs)


class CachedStringifyMapper(StringifyMapper[P],
                            CachedMapper[str, Concatenate[int, P]]):
    """A :class:`StringifyMapper` whose calls are routed through
    :meth:`pymbolic.mapper.CachedMapper.__call__`. See
    :class:`~pymbolic.mapper.CachedMapper` for the caching semantics.
    """

    def __init__(self) -> None:
        # Both base initializers are invoked explicitly rather than through a
        # single cooperative super().__init__() call.
        # NOTE(review): presumably this guarantees CachedMapper's state is set
        # up regardless of the MRO; confirm before switching to super().
        StringifyMapper.__init__(self)
        CachedMapper.__init__(self)

    @override
    def __call__(self, expr: Expression, /, prec: int = PREC_NONE,
                 *args: P.args, **kwargs: P.kwargs) -> str:
        # Use CachedMapper.__call__ rather than StringifyMapper.__call__ (which
        # calls Mapper.__call__ directly). This keeps the same PREC_NONE
        # default for *prec*.
        return CachedMapper.__call__(self, expr, prec, *args, **kwargs)


# }}}


# {{{ cse-splitting stringifier

class CSESplittingStringifyMapperMixin(Mapper[str, Concatenate[int, P]]):
    """A :term:`mix-in` for :class:`StringifyMapper` subclasses that replaces
    each :class:`pymbolic.primitives.CommonSubexpression` with a variable
    name and records a "variable assignment" for it.

    .. attribute:: cse_to_name

        A :class:`dict` from the wrapped (child) expression to the variable
        name assigned to it.

    .. attribute:: cse_names

        A :class:`set` of all names handed out so far.

    .. attribute:: cse_name_list

        A :class:`list` of ``(name, string representation)`` tuples. Entries
        are appended only after a subexpression's own child has been
        rendered, so dependencies always precede their users. Emitting
        assignments in this order never references an undefined variable.

    See :class:`pymbolic.mapper.c_code.CCodeMapper` for an example
    of the use of this mix-in.
    """

    cse_to_name: dict[Expression, str]
    cse_names: set[str]
    cse_name_list: list[tuple[str, str]]

    def __init__(self) -> None:
        self.cse_to_name = {}
        self.cse_names = set()
        self.cse_name_list = []
        super().__init__()

    def map_common_subexpression(
        self,
        expr: p.CommonSubexpression, /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        """Return the variable name standing in for *expr*.

        The name is assigned and recorded on first encounter.
        """
        known_name = self.cse_to_name.get(expr.child)
        if known_name is not None:
            return known_name

        # Render the child before picking a name: nested CSEs get registered
        # during this call and so end up ahead of this one in cse_name_list.
        rendered_child = self.rec(expr.child, PREC_NONE, *args, **kwargs)

        prefix = expr.prefix
        if prefix is None:
            # Anonymous CSEs: CSE0, CSE1, CSE2, ...
            counter = 0
            name = f"CSE{counter}"
            while name in self.cse_names:
                counter += 1
                name = f"CSE{counter}"
        else:
            # Prefixed CSEs: prefix, prefix_2, prefix_3, ...
            name = prefix
            counter = 2
            while name in self.cse_names:
                name = f"{prefix}_{counter}"
                counter += 1

        self.cse_name_list.append((name, rendered_child))
        self.cse_to_name[expr.child] = name
        self.cse_names.add(name)
        return name

    def get_cse_strings(self) -> list[str]:
        """Return one ``"name : string"`` line per recorded CSE.

        Lines are sorted by name (not in dependency order). The result is
        empty if ``cse_name_list`` was never set up.
        """
        recorded: list[tuple[str, str]] = getattr(self, "cse_name_list", [])
        lines: list[str] = []
        for cse_name, cse_str in sorted(recorded):
            lines.append(f"{cse_name} : {cse_str}")
        return lines


# }}}


# {{{ sorting stringifier


class SortingStringifyMapper(StringifyMapper[P]):
    """A :class:`StringifyMapper` that sorts the operands of sums and
    products by their string form, so the output does not depend on the
    order of the children.

    .. attribute:: reverse

        If *True* (the default), operands are sorted in descending string
        order.
    """

    reverse: bool

    def __init__(self, reverse: bool = True) -> None:
        super().__init__()
        self.reverse = reverse

    @override
    def map_sum(
        self, expr: p.Sum, /, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        entries = [self.rec(i, PREC_SUM, *args, **kwargs) for i in expr.children]
        entries.sort(reverse=self.reverse)
        return self.parenthesize_if_needed("+".join(entries), enclosing_prec, PREC_SUM)

    @override
    def map_product(
        self, expr: p.Product, /, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        # Quotients, floor divisions and remainders share the precedence of
        # '*' and associate to the left. Once factors are reordered, an
        # unparenthesized "b // c" following another factor changes meaning:
        # "a*b // c" is "(a*b) // c". Parenthesize them, as
        # StringifyMapper.map_product does.
        entries = [
            self.rec_with_parens_around_types(
                i, PREC_PRODUCT, (p.Quotient, p.FloorDiv, p.Remainder),
                *args, **kwargs)
            for i in expr.children
        ]
        entries.sort(reverse=self.reverse)
        return self.parenthesize_if_needed(
            "*".join(entries), enclosing_prec, PREC_PRODUCT
        )


# }}}


# {{{ simplifying, sorting stringifier


class SimplifyingSortingStringifyMapper(StringifyMapper[P]):
    """A :class:`StringifyMapper` that sorts the operands of sums and
    products by their string form. Sum terms that are products led by a
    ``-1`` factor are rendered as subtractions.

    .. attribute:: reverse

        If *True* (the default), operands are sorted in descending string
        order.
    """

    reverse: bool

    def __init__(self, reverse: bool = True) -> None:
        super().__init__()
        self.reverse = reverse

    @override
    def map_sum(
        self, expr: p.Sum, /, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        def get_neg_product(expr: ArithmeticExpression) -> ArithmeticExpression | None:
            """If *expr* is a product whose first factor satisfies
            ``is_zero(factor + 1)`` (i.e. is -1), return the remaining
            factor(s); otherwise return *None*."""
            if (
                isinstance(expr, p.Product)
                and len(expr.children)
                and p.is_zero(expr.children[0] + 1)
            ):
                if len(expr.children) == 2:
                    # only the minus sign and the other child
                    return expr.children[1]
                else:
                    return p.Product(expr.children[1:])
            else:
                return None

        positives: list[str] = []
        negatives: list[str] = []

        for ch in expr.children:
            neg_prod = get_neg_product(ch)
            if neg_prod is not None:
                # Rendered at product precedence: it follows a binary " - ".
                negatives.append(self.rec(neg_prod, PREC_PRODUCT, *args, **kwargs))
            else:
                positives.append(self.rec(ch, PREC_SUM, *args, **kwargs))

        positives.sort(reverse=self.reverse)
        positives_str = " + ".join(positives)
        negatives.sort(reverse=self.reverse)
        negatives_str = "".join(self.format(" - %s", entry) for entry in negatives)

        result = positives_str + negatives_str

        return self.parenthesize_if_needed(result, enclosing_prec, PREC_SUM)

    @override
    def map_product(
        self, expr: p.Product, /, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        # Quotients, floor divisions and remainders share the precedence of
        # '*' and associate to the left. Once factors are reordered, an
        # unparenthesized "b // c" following another factor changes meaning:
        # "a*b // c" is "(a*b) // c". Parenthesize them, as
        # StringifyMapper.map_product does.
        entries = [
            self.rec_with_parens_around_types(
                child, PREC_PRODUCT, (p.Quotient, p.FloorDiv, p.Remainder),
                *args, **kwargs)
            for child in expr.children
        ]

        entries.sort(reverse=self.reverse)
        result = "*".join(entries)

        return self.parenthesize_if_needed(result, enclosing_prec, PREC_PRODUCT)


# }}}


# {{{ latex stringifier


class LaTeXMapper(StringifyMapper[P]):
    """Render :mod:`pymbolic` expressions as LaTeX math-mode markup.

    Only the node types whose rendering differs from the plain-text
    :class:`StringifyMapper` output are overridden here. Every other node
    type is rendered by :class:`StringifyMapper`.
    """

    # Maps the operator strings of :class:`~pymbolic.primitives.Comparison`
    # to the LaTeX relation symbols used by :meth:`map_comparison`.
    COMPARISON_OP_TO_LATEX: ClassVar[dict[str, str]] = {
        "==": r"=",
        "!=": r"\ne",
        "<=": r"\le",
        ">=": r"\ge",
        "<": r"<",
        ">": r">",
    }

    @override
    def map_remainder(
        self, expr: p.Remainder, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        r"""Render ``a % b`` as ``(a \bmod b)``.

        The result is always wrapped in parentheses, so *enclosing_prec* is
        not consulted.
        """
        return self.format(
            r"(%s \bmod %s)",
            self.rec(expr.numerator, PREC_PRODUCT, *args, **kwargs),
            self.rec(expr.denominator, PREC_POWER, *args, **kwargs),
        )

    @override
    def map_left_shift(
        self, expr: p.LeftShift, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        r"""Render ``a << b`` as ``a \ll b``.

        Both operands are rendered one level above :data:`PREC_SHIFT`, so a
        nested shift on either side gets parenthesized.
        """
        return self.parenthesize_if_needed(
            self.format(
                r"%s \ll %s",
                self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs),
                self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs),
            ),
            enclosing_prec,
            PREC_SHIFT,
        )

    @override
    def map_right_shift(
        self,
        expr: p.RightShift, /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        r"""Render ``a >> b`` as ``a \gg b``.

        Both operands are rendered one level above :data:`PREC_SHIFT`, so a
        nested shift on either side gets parenthesized.
        """
        return self.parenthesize_if_needed(
            self.format(
                r"%s \gg %s",
                self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs),
                self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs),
            ),
            enclosing_prec,
            PREC_SHIFT,
        )

    @override
    def map_bitwise_xor(
        self,
        expr: p.BitwiseXor, /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        r"""Render ``a ^ b`` as ``a \oplus b``."""
        # \wedge is what map_logical_and emits. Using it here too would make
        # bitwise xor and logical and typeset identically, so use the
        # conventional exclusive-or symbol \oplus instead.
        return self.parenthesize_if_needed(
            self.join_rec(
                r" \oplus ", expr.children, PREC_BITWISE_XOR, *args, **kwargs
            ),
            enclosing_prec,
            PREC_BITWISE_XOR,
        )

    @override
    def map_product(
        self, expr: p.Product, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        """Render a product as its factors separated by spaces."""
        return self.parenthesize_if_needed(
            self.join_rec(" ", expr.children, PREC_PRODUCT, *args, **kwargs),
            enclosing_prec,
            PREC_PRODUCT,
        )

    @override
    def map_power(
        self, expr: p.Power, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        """Render ``base ** exponent`` as ``{base}^{exponent}``.

        The exponent sits inside a superscript group, so it never needs
        parentheses. The base does: ``{a + b}^{2}`` is typeset like
        ``a + b^2``. The base is therefore rendered one level above
        :data:`PREC_POWER`, which parenthesizes any compound base, including
        a nested power or a negative constant.
        """
        return self.parenthesize_if_needed(
            self.format(
                "{%s}^{%s}",
                self.rec(expr.base, PREC_POWER + 1, *args, **kwargs),
                self.rec(expr.exponent, PREC_NONE, *args, **kwargs),
            ),
            enclosing_prec,
            PREC_POWER,
        )

    def _map_min_max(
        self, expr: p.Min | p.Max, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        r"""Shared implementation of :meth:`map_min` and :meth:`map_max`.

        If :func:`pytools.is_single_valued` reports that all children
        coincide, the first child is rendered directly in the enclosing
        context. Otherwise the lower-cased class name is used as a LaTeX
        macro, e.g. ``\min(a, b)`` or ``\max(a, b)``.
        """
        from pytools import is_single_valued

        if is_single_valued(expr.children):
            return self.rec(expr.children[0], enclosing_prec, *args, **kwargs)

        what = type(expr).__name__.lower()
        return self.format(
            r"\%s(%s)",
            what,
            self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs),
        )

    @override
    def map_min(
        self, expr: p.Min, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        r"""Render a minimum as ``\min(...)``; see :meth:`_map_min_max`."""
        return self._map_min_max(expr, enclosing_prec, *args, **kwargs)

    @override
    def map_max(
        self, expr: p.Max, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        r"""Render a maximum as ``\max(...)``; see :meth:`_map_min_max`."""
        return self._map_min_max(expr, enclosing_prec, *args, **kwargs)

    @override
    def map_floor_div(
        self, expr: p.FloorDiv, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        r"""Render ``a // b`` as ``\lfloor {a} / {b} \rfloor``.

        The floor brackets delimit the expression, so neither the operands
        nor the result are parenthesized.
        """
        return self.format(
            r"\lfloor {%s} / {%s} \rfloor",
            self.rec(expr.numerator, PREC_NONE, *args, **kwargs),
            self.rec(expr.denominator, PREC_NONE, *args, **kwargs),
        )

    @override
    def map_subscript(
        self, expr: p.Subscript, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        """Render ``a[i, j]`` as ``{a}_{i, j}``.

        Tuple indices are joined with commas; any other index is rendered
        as a single expression.
        """
        if isinstance(expr.index, tuple):
            index_str = self.join_rec(", ", expr.index, PREC_NONE, *args, **kwargs)
        else:
            index_str = self.rec(expr.index, PREC_NONE, *args, **kwargs)

        return self.format(
            "{%s}_{%s}", self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), index_str
        )

    @override
    def map_logical_not(
        self,
        expr: p.LogicalNot, /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        r"""Render ``not a`` as ``\neg a``."""
        return self.parenthesize_if_needed(
            r"\neg " + self.rec(expr.child, PREC_UNARY, *args, **kwargs),
            enclosing_prec,
            PREC_UNARY,
        )

    @override
    def map_logical_or(
        self, expr: p.LogicalOr, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        r"""Render ``a or b`` as ``a \vee b``."""
        return self.parenthesize_if_needed(
            self.join_rec(r" \vee ", expr.children, PREC_LOGICAL_OR, *args, **kwargs),
            enclosing_prec,
            PREC_LOGICAL_OR,
        )

    @override
    def map_logical_and(
        self,
        expr: p.LogicalAnd, /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        r"""Render ``a and b`` as ``a \wedge b``."""
        return self.parenthesize_if_needed(
            self.join_rec(
                r" \wedge ", expr.children, PREC_LOGICAL_AND, *args, **kwargs
            ),
            enclosing_prec,
            PREC_LOGICAL_AND,
        )

    @override
    def map_comparison(
        self, expr: p.Comparison, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        """Render a comparison using the symbol from
        :attr:`COMPARISON_OP_TO_LATEX`.

        An operator missing from that table raises :exc:`KeyError`.
        """
        return self.parenthesize_if_needed(
            self.format(
                "%s %s %s",
                self.rec(expr.left, PREC_COMPARISON, *args, **kwargs),
                self.COMPARISON_OP_TO_LATEX[expr.operator],
                self.rec(expr.right, PREC_COMPARISON, *args, **kwargs),
            ),
            enclosing_prec,
            PREC_COMPARISON,
        )

    @override
    def map_substitution(
        self,
        expr: p.Substitution, /,
        enclosing_prec: int,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        r"""Render a substitution as ``[child]\{x=value, ...\}``.

        Variable names are inserted verbatim. ``zip(..., strict=True)``
        raises if ``variables`` and ``values`` differ in length.
        """
        substs = ", ".join(
            f"{name}={self.rec(val, PREC_NONE, *args, **kwargs)}"
            for name, val in zip(expr.variables, expr.values, strict=True)
        )

        return self.format(
            r"[%s]\{%s\}", self.rec(expr.child, PREC_NONE, *args, **kwargs), substs
        )

    @override
    def map_derivative(
        self, expr: p.Derivative, /,
        enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
    ) -> str:
        r"""Render a derivative as ``\frac{\partial}{\partial v} child``.

        One ``\frac{\partial}{\partial v}`` operator is emitted per
        differentiation variable, in order. The child is rendered at
        :data:`PREC_PRODUCT`.
        """
        ddv = " ".join(fr"\frac{{\partial}}{{\partial {v}}}" for v in expr.variables)
        child = self.rec(expr.child, PREC_PRODUCT, *args, **kwargs)

        return self.format("%s %s", ddv, child)


# }}}

# vim: fdm=marker
