From 049295e2d97fef4a7b7101b7c8ca7b54327a42ce Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Thu, 28 Jul 2022 00:14:02 -0500
Subject: [PATCH] adds PymbolicToASTMapper

---
 pymbolic/interop/ast.py | 251 +++++++++++++++++++++++++++++++++++++++-
 1 file changed, 250 insertions(+), 1 deletion(-)

diff --git a/pymbolic/interop/ast.py b/pymbolic/interop/ast.py
index 6824d57..0756a22 100644
--- a/pymbolic/interop/ast.py
+++ b/pymbolic/interop/ast.py
@@ -1,4 +1,7 @@
-__copyright__ = "Copyright (C) 2015 Andreas Kloeckner"
+__copyright__ = """
+Copyright (C) 2015 Andreas Kloeckner
+Copyright (C) 2022 Kaushik Kulkarni
+"""
 
 __license__ = """
 Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -22,6 +25,9 @@ THE SOFTWARE.
 
 import ast
 import pymbolic.primitives as p
+from typing import Tuple, List, Any
+from pymbolic.typing import ExpressionT, ScalarT
+from pymbolic.mapper import CachedMapper
 
 __doc__ = r'''
 
@@ -252,4 +258,247 @@ class ASTToPymbolic(ASTMapper):
 
 # }}}
 
+
+# {{{ PymbolicToASTMapper
+
+class PymbolicToASTMapper(CachedMapper):
+    def map_variable(self, expr) -> ast.expr:
+        return ast.Name(id=expr.name)
+
+    def _map_multi_children_op(self,
+                               children: Tuple[ExpressionT, ...],
+                               op_type: ast.operator) -> ast.expr:
+        rec_children = [self.rec(child) for child in children]
+        result = rec_children[-1]
+        for child in rec_children[-2::-1]:
+            result = ast.BinOp(child, op_type, result)
+
+        return result
+
+    def map_sum(self, expr: p.Sum) -> ast.expr:
+        return self._map_multi_children_op(expr.children, ast.Add())
+
+    def map_product(self, expr: p.Product) -> ast.expr:
+        return self._map_multi_children_op(expr.children, ast.Mult())
+
+    def map_constant(self, expr: ScalarT) -> ast.expr:
+        import sys
+        if isinstance(expr, bool):
+            return ast.NameConstant(expr)
+        else:
+            # needed because of https://bugs.python.org/issue36280
+            if sys.version_info < (3, 8):
+                return ast.Num(expr)
+            else:
+                return ast.Constant(expr, None)
+
+    def map_call(self, expr: p.Call) -> ast.expr:
+        return ast.Call(
+            func=self.rec(expr.function),
+            args=[self.rec(param) for param in expr.parameters])
+
+    def map_call_with_kwargs(self, expr) -> ast.expr:
+        return ast.Call(
+            func=self.rec(expr.function),
+            args=[self.rec(param) for param in expr.parameters],
+            keywords=[
+                ast.keyword(
+                    arg=kw,
+                    value=self.rec(param))
+                for kw, param in sorted(expr.kw_parameters.items())])
+
+    def map_subscript(self, expr) -> ast.expr:
+        return ast.Subscript(value=self.rec(expr.aggregate),
+                             slice=self.rec(expr.index))
+
+    def map_lookup(self, expr) -> ast.expr:
+        return ast.Attribute(self.rec(expr.aggregate),
+                             expr.name)
+
+    def map_quotient(self, expr) -> ast.expr:
+        return self._map_multi_children_op((expr.numerator,
+                                            expr.denominator),
+                                           ast.Div())
+
+    def map_floor_div(self, expr) -> ast.expr:
+        return self._map_multi_children_op((expr.numerator,
+                                            expr.denominator),
+                                           ast.FloorDiv())
+
+    def map_remainder(self, expr) -> ast.expr:
+        return self._map_multi_children_op((expr.numerator,
+                                            expr.denominator),
+                                           ast.Mod())
+
+    def map_power(self, expr) -> ast.expr:
+        return self._map_multi_children_op((expr.base,
+                                            expr.exponent),
+                                           ast.Pow())
+
+    def map_left_shift(self, expr) -> ast.expr:
+        return self._map_multi_children_op((expr.shiftee,
+                                            expr.shift),
+                                           ast.LShift())
+
+    def map_right_shift(self, expr) -> ast.expr:
+        return self._map_multi_children_op((expr.numerator,
+                                            expr.denominator),
+                                           ast.RShift())
+
+    def map_bitwise_not(self, expr) -> ast.expr:
+        return ast.UnaryOp(ast.Invert(), self.rec(expr.child))
+
+    def map_bitwise_or(self, expr) -> ast.expr:
+        return self._map_multi_children_op(expr.children,
+                                           ast.BitOr())
+
+    def map_bitwise_xor(self, expr) -> ast.expr:
+        return self._map_multi_children_op(expr.children,
+                                           ast.BitXor())
+
+    def map_bitwise_and(self, expr) -> ast.expr:
+        return self._map_multi_children_op(expr.children,
+                                           ast.BitAnd())
+
+    def map_logical_not(self, expr) -> ast.expr:
+        return ast.UnaryOp(self.rec(expr.child), ast.Not())
+
+    def map_logical_or(self, expr) -> ast.expr:
+        return ast.BoolOp(ast.Or(), [self.rec(child)
+                                     for child in expr.children])
+
+    def map_logical_and(self, expr) -> ast.expr:
+        return ast.BoolOp(ast.And(), [self.rec(child)
+                                     for child in expr.children])
+
+    def map_list(self, expr: List[Any]) -> ast.expr:
+        return ast.List([self.rec(el) for el in expr])
+
+    def map_tuple(self, expr: Tuple[Any, ...]) -> ast.expr:
+        return ast.Tuple([self.rec(el) for el in expr])
+
+    def map_if(self, expr: p.If) -> ast.expr:
+        return ast.IfExp(test=self.rec(expr.condition),
+                         body=self.rec(expr.then),
+                         orelse=self.rec(expr.else_))
+
+    def map_nan(self, expr: p.NaN) -> ast.expr:
+        if isinstance(expr.data_type(float("nan")), float):
+            return ast.Call(
+                ast.Name(id="float"),
+                args=[ast.Constant("nan")],
+                keywords=[])
+        else:
+            # TODO: would need attributes of NumPy
+            raise NotImplementedError("Non-float nan not implemented")
+
+    def map_slice(self, expr: p.Slice) -> ast.expr:
+        return ast.Slice(*[self.rec(child)
+                           for child in expr.children])
+
+    def map_numpy_array(self, expr) -> ast.expr:
+        raise NotImplementedError
+
+    def map_multivector(self, expr) -> ast.expr:
+        raise NotImplementedError
+
+    def map_common_subexpression(self, expr) -> ast.expr:
+        raise NotImplementedError
+
+    def map_substitution(self, expr) -> ast.expr:
+        raise NotImplementedError
+
+    def map_derivative(self, expr) -> ast.expr:
+        raise NotImplementedError
+
+    def map_if_positive(self, expr) -> ast.expr:
+        raise NotImplementedError
+
+    def map_comparison(self, expr: p.Comparison) -> ast.expr:
+        raise NotImplementedError
+
+    def map_polynomial(self, expr) -> ast.expr:
+        raise NotImplementedError
+
+    def map_wildcard(self, expr) -> ast.expr:
+        raise NotImplementedError
+
+    def map_dot_wildcard(self, expr) -> ast.expr:
+        raise NotImplementedError
+
+    def map_star_wildcard(self, expr) -> ast.expr:
+        raise NotImplementedError
+
+    def map_function_symbol(self, expr) -> ast.expr:
+        raise NotImplementedError
+
+    def map_min(self, expr) -> ast.expr:
+        raise NotImplementedError
+
+    def map_max(self, expr) -> ast.expr:
+        raise NotImplementedError
+
+
+def to_python_ast(expr) -> ast.expr:
+    """
+    Maps *expr* to :class:`ast.expr`.
+    """
+    return PymbolicToASTMapper()(expr)
+
+
+def to_evaluatable_python_function(expr: ExpressionT,
+                                   fn_name: str
+                                   ) -> str:
+    """
+    Returns a :class:`str` of the python code with a single function *fn_name*
+    that takes in the variables in *expr* as keyword-only arguments and returns
+    the evaluated value of *expr*.
+
+    .. testsetup::
+
+        >>> from pymbolic import parse
+        >>> from pymbolic.interop.ast import to_evaluatable_python_function
+
+    .. doctest::
+
+        >>> expr = parse("S//32 + E%32")
+        >>> # Skipping doctest as astunparse and ast.unparse have certain subtle
+        >>> # differences
+        >>> print(to_evaluatable_python_function(expr, "foo"))) # doctest: +SKIP
+        def foo(*, E, S):
+            return S // 32 + E % 32
+    """
+    import sys
+    from pymbolic.mapper.dependency import CachedDependencyMapper
+
+    if sys.version_info < (3, 9):
+        try:
+            from astunparse import unparse
+        except ImportError:
+            raise RuntimeError("'to_evaluate_python_function' needs"
+                               "astunparse for Py<3.9. Install via `pip"
+                               " install astunparse`")
+    else:
+        unparse = ast.unparse
+
+    dep_mapper = CachedDependencyMapper(composite_leaves=True)
+    deps = sorted({dep.name for dep in dep_mapper(expr)})
+
+    ast_func = ast.FunctionDef(name=fn_name,
+                               args=ast.arguments(args=[],
+                                                  posonlyargs=[],
+                                                  kwonlyargs=[ast.arg(dep, None)
+                                                              for dep in deps],
+                                                  kw_defaults=[None]*len(deps),
+                                                  vararg=None,
+                                                  kwarg=None,
+                                                  defaults=[]),
+                               body=[ast.Return(to_python_ast(expr))],
+                               decorator_list=[])
+    ast_module = ast.Module([ast_func], type_ignores=[])
+
+    return unparse(ast.fix_missing_locations(ast_module))
+
+# }}}
+
 # vim: foldmethod=marker
-- 
GitLab