diff --git a/doc/mappers.rst b/doc/mappers.rst index ce56ea91269546489e57d5a664885c499aa0a25a..a3c60b6f4c143a9278f5704c641db36305abc170 100644 --- a/doc/mappers.rst +++ b/doc/mappers.rst @@ -12,15 +12,6 @@ Converting to strings and code .. automodule:: pymbolic.mapper.stringifier -Mappers -******* - -.. autoclass:: StringifyMapper - - .. automethod:: __call__ - -.. autoclass:: CSESplittingStringifyMapperMixin - .. automodule:: pymbolic.mapper.c_code .. autoclass:: CCodeMapper diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index 44057dfb816ee9e374ef919b5061361d878aa80b..26c2932342a6ae9d28b1f7d4f88ea8ab4729c789 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -43,6 +43,17 @@ Precedence constants .. data:: PREC_LOGICAL_AND .. data:: PREC_LOGICAL_OR .. data:: PREC_NONE + +Mappers +******* + +.. autoclass:: StringifyMapper + + .. automethod:: __call__ + +.. autoclass:: CSESplittingStringifyMapperMixin + +.. autoclass:: LaTeXMapper """ @@ -562,4 +573,126 @@ class SimplifyingSortingStringifyMapper(StringifyMapper): # }}} + +# {{{ latex stringifier + +class LaTeXMapper(StringifyMapper): + + COMPARISON_OP_TO_LATEX = { + "==": r"=", + "!=": r"\ne", + "<=": r"\le", + ">=": r"\ge", + "<": r"<", + ">": r">", + } + + def map_remainder(self, expr, enclosing_prec, *args, **kwargs): + return self.format(r"(%s \bmod %s)", + self.rec(expr.numerator, PREC_PRODUCT, *args, **kwargs), + self.rec(expr.denominator, PREC_POWER, *args, **kwargs)), + + def map_left_shift(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.format(r"%s \ll %s", + self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs), + self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)), + enclosing_prec, PREC_SHIFT) + + def map_right_shift(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.format(r"%s \gg %s", + self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs), + self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)), + enclosing_prec, PREC_SHIFT) + + def map_bitwise_xor(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.join_rec( + r" \wedge ", expr.children, PREC_BITWISE_XOR, *args, **kwargs), + enclosing_prec, PREC_BITWISE_XOR) + + def map_product(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.join_rec(" ", expr.children, PREC_PRODUCT, *args, **kwargs), + enclosing_prec, PREC_PRODUCT) + + def map_power(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.format("{%s}^{%s}", + self.rec(expr.base, PREC_NONE, *args, **kwargs), + self.rec(expr.exponent, PREC_NONE, *args, **kwargs)), + enclosing_prec, PREC_NONE) + + def map_min(self, expr, enclosing_prec, *args, **kwargs): + from pytools import is_single_valued + if is_single_valued(expr.children): + return self.rec(expr.children[0], enclosing_prec) + + what = type(expr).__name__.lower() + return self.format(r"\%s(%s)", + what, self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs)) + + def map_max(self, expr, enclosing_prec): + return self.map_min(expr, enclosing_prec) + + def map_floor_div(self, expr, enclosing_prec, *args, **kwargs): + return self.format(r"\lfloor {%s} / {%s} \rfloor", + self.rec(expr.numerator, PREC_NONE, *args, **kwargs), + self.rec(expr.denominator, PREC_NONE, *args, **kwargs)) + + def map_subscript(self, expr, enclosing_prec, *args, **kwargs): + if isinstance(expr.index, tuple): + index_str = self.join_rec(", ", expr.index, PREC_NONE, *args, **kwargs) + else: + index_str = self.rec(expr.index, PREC_NONE, *args, **kwargs) + + return self.format("{%s}_{%s}", + self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), + index_str) + + def map_logical_not(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + r"\neg " + self.rec(expr.child, PREC_UNARY, *args, **kwargs), + enclosing_prec, PREC_UNARY) + + def map_logical_or(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.join_rec( + r" \vee ", expr.children, PREC_LOGICAL_OR, *args, **kwargs), + enclosing_prec, PREC_LOGICAL_OR) + + def map_logical_and(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.join_rec( + r" \wedge ", expr.children, PREC_LOGICAL_AND, *args, **kwargs), + enclosing_prec, PREC_LOGICAL_AND) + + def map_comparison(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.format("%s %s %s", + self.rec(expr.left, PREC_COMPARISON, *args, **kwargs), + self.COMPARISON_OP_TO_LATEX[expr.operator], + self.rec(expr.right, PREC_COMPARISON, *args, **kwargs)), + enclosing_prec, PREC_COMPARISON) + + def map_substitution(self, expr, enclosing_prec, *args, **kwargs): + substs = ", ".join( + "%s=%s" % (name, self.rec(val, PREC_NONE, *args, **kwargs)) + for name, val in zip(expr.variables, expr.values)) + + return self.format("[%s]\{%s\}", + self.rec(expr.child, PREC_NONE, *args, **kwargs), + substs) + + def map_derivative(self, expr, enclosing_prec, *args, **kwargs): + derivs = " ".join( + r"\frac{\partial}{\partial %s}" % v + for v in expr.variables) + + return self.format("%s %s", + derivs, self.rec(expr.child, PREC_PRODUCT, *args, **kwargs)) + +# }}} + # vim: fdm=marker diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index f6b8733dc4d69ad87bb47985a605b116e4a6a699..bb92ac841a8c2be2cce9e4a704561408fd4e2b09 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -492,6 +492,74 @@ def test_stringifier_preserve_shift_order(): assert parse(str(expr)) == expr +LATEX_TEMPLATE = r"""\documentclass{article} +\usepackage{amsmath} + +\begin{document} +%s +\end{document}""" + + +def test_latex_mapper(): + from pymbolic import parse + from pymbolic.mapper.stringifier import LaTeXMapper, StringifyMapper + + tm = LaTeXMapper() + sm = StringifyMapper() + + equations = [] + + def add(expr): + # Add an equation to the list of tests. + equations.append("\[%s\] %% from: %s" % (tm(expr), sm(expr))) + + add(parse("a * b + c")) + add(parse("f(a,b,c)")) + add(parse("a ** b ** c")) + add(parse("(a | b) ^ ~c")) + add(parse("a << b")) + add(parse("a >> b")) + add(parse("a[i,j,k]")) + add(parse("a[1:3]")) + add(parse("a // b")) + add(parse("not (a or b) and c")) + add(parse("(a % b) % c")) + add(parse("(a >= b) or (b <= c)")) + add(prim.Min((1,)) + prim.Max((1, 2))) + add(prim.Substitution(prim.Variable("x") ** 2, ("x",), (2,))) + add(prim.Derivative(parse("x**2"), ("x",))) + + # Run LaTeX and ensure the file compiles. + import os + import tempfile + import subprocess + import shutil + + latex_dir = tempfile.mkdtemp("pymbolic") + + try: + tex_file_path = os.path.join(latex_dir, "input.tex") + + with open(tex_file_path, "w") as tex_file: + contents = LATEX_TEMPLATE % "\n".join(equations) + tex_file.write(contents) + + try: + subprocess.check_output( + ["latex", + "-interaction=nonstopmode", + "-output-directory=%s" % latex_dir, + tex_file_path], + universal_newlines=True) + except FileNotFoundError: + pytest.skip("latex command not found") + except subprocess.CalledProcessError as err: + assert False, str(err.output) + + finally: + shutil.rmtree(latex_dir) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: