From 049295e2d97fef4a7b7101b7c8ca7b54327a42ce Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Thu, 28 Jul 2022 00:14:02 -0500 Subject: [PATCH] adds PymbolicToASTMapper --- pymbolic/interop/ast.py | 251 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 250 insertions(+), 1 deletion(-) diff --git a/pymbolic/interop/ast.py b/pymbolic/interop/ast.py index 6824d57..0756a22 100644 --- a/pymbolic/interop/ast.py +++ b/pymbolic/interop/ast.py @@ -1,4 +1,7 @@ -__copyright__ = "Copyright (C) 2015 Andreas Kloeckner" +__copyright__ = """ +Copyright (C) 2015 Andreas Kloeckner +Copyright (C) 2022 Kaushik Kulkarni +""" __license__ = """ Permission is hereby granted, free of charge, to any person obtaining a copy @@ -22,6 +25,9 @@ THE SOFTWARE. import ast import pymbolic.primitives as p +from typing import Tuple, List, Any +from pymbolic.typing import ExpressionT, ScalarT +from pymbolic.mapper import CachedMapper __doc__ = r''' @@ -252,4 +258,247 @@ class ASTToPymbolic(ASTMapper): # }}} + +# {{{ PymbolicToASTMapper + +class PymbolicToASTMapper(CachedMapper): + def map_variable(self, expr) -> ast.expr: + return ast.Name(id=expr.name) + + def _map_multi_children_op(self, + children: Tuple[ExpressionT, ...], + op_type: ast.operator) -> ast.expr: + rec_children = [self.rec(child) for child in children] + result = rec_children[-1] + for child in rec_children[-2::-1]: + result = ast.BinOp(child, op_type, result) + + return result + + def map_sum(self, expr: p.Sum) -> ast.expr: + return self._map_multi_children_op(expr.children, ast.Add()) + + def map_product(self, expr: p.Product) -> ast.expr: + return self._map_multi_children_op(expr.children, ast.Mult()) + + def map_constant(self, expr: ScalarT) -> ast.expr: + import sys + if isinstance(expr, bool): + return ast.NameConstant(expr) + else: + # needed because of https://bugs.python.org/issue36280 + if sys.version_info < (3, 8): + return ast.Num(expr) + else: + return ast.Constant(expr, None) + + def map_call(self, expr: p.Call) -> ast.expr: + return ast.Call( + func=self.rec(expr.function), + args=[self.rec(param) for param in expr.parameters]) + + def map_call_with_kwargs(self, expr) -> ast.expr: + return ast.Call( + func=self.rec(expr.function), + args=[self.rec(param) for param in expr.parameters], + keywords=[ + ast.keyword( + arg=kw, + value=self.rec(param)) + for kw, param in sorted(expr.kw_parameters.items())]) + + def map_subscript(self, expr) -> ast.expr: + return ast.Subscript(value=self.rec(expr.aggregate), + slice=self.rec(expr.index)) + + def map_lookup(self, expr) -> ast.expr: + return ast.Attribute(self.rec(expr.aggregate), + expr.name) + + def map_quotient(self, expr) -> ast.expr: + return self._map_multi_children_op((expr.numerator, + expr.denominator), + ast.Div()) + + def map_floor_div(self, expr) -> ast.expr: + return self._map_multi_children_op((expr.numerator, + expr.denominator), + ast.FloorDiv()) + + def map_remainder(self, expr) -> ast.expr: + return self._map_multi_children_op((expr.numerator, + expr.denominator), + ast.Mod()) + + def map_power(self, expr) -> ast.expr: + return self._map_multi_children_op((expr.base, + expr.exponent), + ast.Pow()) + + def map_left_shift(self, expr) -> ast.expr: + return self._map_multi_children_op((expr.shiftee, + expr.shift), + ast.LShift()) + + def map_right_shift(self, expr) -> ast.expr: + return self._map_multi_children_op((expr.numerator, + expr.denominator), + ast.RShift()) + + def map_bitwise_not(self, expr) -> ast.expr: + return ast.UnaryOp(ast.Invert(), self.rec(expr.child)) + + def map_bitwise_or(self, expr) -> ast.expr: + return self._map_multi_children_op(expr.children, + ast.BitOr()) + + def map_bitwise_xor(self, expr) -> ast.expr: + return self._map_multi_children_op(expr.children, + ast.BitXor()) + + def map_bitwise_and(self, expr) -> ast.expr: + return self._map_multi_children_op(expr.children, + ast.BitAnd()) + + def map_logical_not(self, expr) -> ast.expr: + return ast.UnaryOp(self.rec(expr.child), ast.Not()) + + def map_logical_or(self, expr) -> ast.expr: + return ast.BoolOp(ast.Or(), [self.rec(child) + for child in expr.children]) + + def map_logical_and(self, expr) -> ast.expr: + return ast.BoolOp(ast.And(), [self.rec(child) + for child in expr.children]) + + def map_list(self, expr: List[Any]) -> ast.expr: + return ast.List([self.rec(el) for el in expr]) + + def map_tuple(self, expr: Tuple[Any, ...]) -> ast.expr: + return ast.Tuple([self.rec(el) for el in expr]) + + def map_if(self, expr: p.If) -> ast.expr: + return ast.IfExp(test=self.rec(expr.condition), + body=self.rec(expr.then), + orelse=self.rec(expr.else_)) + + def map_nan(self, expr: p.NaN) -> ast.expr: + if isinstance(expr.data_type(float("nan")), float): + return ast.Call( + ast.Name(id="float"), + args=[ast.Constant("nan")], + keywords=[]) + else: + # TODO: would need attributes of NumPy + raise NotImplementedError("Non-float nan not implemented") + + def map_slice(self, expr: p.Slice) -> ast.expr: + return ast.Slice(*[self.rec(child) + for child in expr.children]) + + def map_numpy_array(self, expr) -> ast.expr: + raise NotImplementedError + + def map_multivector(self, expr) -> ast.expr: + raise NotImplementedError + + def map_common_subexpression(self, expr) -> ast.expr: + raise NotImplementedError + + def map_substitution(self, expr) -> ast.expr: + raise NotImplementedError + + def map_derivative(self, expr) -> ast.expr: + raise NotImplementedError + + def map_if_positive(self, expr) -> ast.expr: + raise NotImplementedError + + def map_comparison(self, expr: p.Comparison) -> ast.expr: + raise NotImplementedError + + def map_polynomial(self, expr) -> ast.expr: + raise NotImplementedError + + def map_wildcard(self, expr) -> ast.expr: + raise NotImplementedError + + def map_dot_wildcard(self, expr) -> ast.expr: + raise NotImplementedError + + def map_star_wildcard(self, expr) -> ast.expr: + raise NotImplementedError + + def map_function_symbol(self, expr) -> ast.expr: + raise NotImplementedError + + def map_min(self, expr) -> ast.expr: + raise NotImplementedError + + def map_max(self, expr) -> ast.expr: + raise NotImplementedError + + +def to_python_ast(expr) -> ast.expr: + """ + Maps *expr* to :class:`ast.expr`. + """ + return PymbolicToASTMapper()(expr) + + +def to_evaluatable_python_function(expr: ExpressionT, + fn_name: str + ) -> str: + """ + Returns a :class:`str` of the python code with a single function *fn_name* + that takes in the variables in *expr* as keyword-only arguments and returns + the evaluated value of *expr*. + + .. testsetup:: + + >>> from pymbolic import parse + >>> from pymbolic.interop.ast import to_evaluatable_python_function + + .. doctest:: + + >>> expr = parse("S//32 + E%32") + >>> # Skipping doctest as astunparse and ast.unparse have certain subtle + >>> # differences + >>> print(to_evaluatable_python_function(expr, "foo"))) # doctest: +SKIP + def foo(*, E, S): + return S // 32 + E % 32 + """ + import sys + from pymbolic.mapper.dependency import CachedDependencyMapper + + if sys.version_info < (3, 9): + try: + from astunparse import unparse + except ImportError: + raise RuntimeError("'to_evaluate_python_function' needs" + "astunparse for Py<3.9. Install via `pip" + " install astunparse`") + else: + unparse = ast.unparse + + dep_mapper = CachedDependencyMapper(composite_leaves=True) + deps = sorted({dep.name for dep in dep_mapper(expr)}) + + ast_func = ast.FunctionDef(name=fn_name, + args=ast.arguments(args=[], + posonlyargs=[], + kwonlyargs=[ast.arg(dep, None) + for dep in deps], + kw_defaults=[None]*len(deps), + vararg=None, + kwarg=None, + defaults=[]), + body=[ast.Return(to_python_ast(expr))], + decorator_list=[]) + ast_module = ast.Module([ast_func], type_ignores=[]) + + return unparse(ast.fix_missing_locations(ast_module)) + +# }}} + # vim: foldmethod=marker -- GitLab