Skip to content
Snippets Groups Projects
Commit 049295e2 authored by Kaushik Kulkarni's avatar Kaushik Kulkarni Committed by Andreas Klöckner
Browse files

adds PymbolicToASTMapper

parent a91e8653
No related branches found
No related tags found
No related merge requests found
__copyright__ = "Copyright (C) 2015 Andreas Kloeckner"
__copyright__ = """
Copyright (C) 2015 Andreas Kloeckner
Copyright (C) 2022 Kaushik Kulkarni
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
......@@ -22,6 +25,9 @@ THE SOFTWARE.
import ast
import pymbolic.primitives as p
from typing import Tuple, List, Any
from pymbolic.typing import ExpressionT, ScalarT
from pymbolic.mapper import CachedMapper
__doc__ = r'''
......@@ -252,4 +258,247 @@ class ASTToPymbolic(ASTMapper):
# }}}
# {{{ PymbolicToASTMapper
class PymbolicToASTMapper(CachedMapper):
def map_variable(self, expr) -> ast.expr:
return ast.Name(id=expr.name)
def _map_multi_children_op(self,
children: Tuple[ExpressionT, ...],
op_type: ast.operator) -> ast.expr:
rec_children = [self.rec(child) for child in children]
result = rec_children[-1]
for child in rec_children[-2::-1]:
result = ast.BinOp(child, op_type, result)
return result
def map_sum(self, expr: p.Sum) -> ast.expr:
return self._map_multi_children_op(expr.children, ast.Add())
def map_product(self, expr: p.Product) -> ast.expr:
return self._map_multi_children_op(expr.children, ast.Mult())
def map_constant(self, expr: ScalarT) -> ast.expr:
import sys
if isinstance(expr, bool):
return ast.NameConstant(expr)
else:
# needed because of https://bugs.python.org/issue36280
if sys.version_info < (3, 8):
return ast.Num(expr)
else:
return ast.Constant(expr, None)
def map_call(self, expr: p.Call) -> ast.expr:
return ast.Call(
func=self.rec(expr.function),
args=[self.rec(param) for param in expr.parameters])
def map_call_with_kwargs(self, expr) -> ast.expr:
return ast.Call(
func=self.rec(expr.function),
args=[self.rec(param) for param in expr.parameters],
keywords=[
ast.keyword(
arg=kw,
value=self.rec(param))
for kw, param in sorted(expr.kw_parameters.items())])
def map_subscript(self, expr) -> ast.expr:
return ast.Subscript(value=self.rec(expr.aggregate),
slice=self.rec(expr.index))
def map_lookup(self, expr) -> ast.expr:
return ast.Attribute(self.rec(expr.aggregate),
expr.name)
def map_quotient(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.Div())
def map_floor_div(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.FloorDiv())
def map_remainder(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.Mod())
def map_power(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.base,
expr.exponent),
ast.Pow())
def map_left_shift(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.shiftee,
expr.shift),
ast.LShift())
def map_right_shift(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.RShift())
def map_bitwise_not(self, expr) -> ast.expr:
return ast.UnaryOp(ast.Invert(), self.rec(expr.child))
def map_bitwise_or(self, expr) -> ast.expr:
return self._map_multi_children_op(expr.children,
ast.BitOr())
def map_bitwise_xor(self, expr) -> ast.expr:
return self._map_multi_children_op(expr.children,
ast.BitXor())
def map_bitwise_and(self, expr) -> ast.expr:
return self._map_multi_children_op(expr.children,
ast.BitAnd())
def map_logical_not(self, expr) -> ast.expr:
return ast.UnaryOp(self.rec(expr.child), ast.Not())
def map_logical_or(self, expr) -> ast.expr:
return ast.BoolOp(ast.Or(), [self.rec(child)
for child in expr.children])
def map_logical_and(self, expr) -> ast.expr:
return ast.BoolOp(ast.And(), [self.rec(child)
for child in expr.children])
def map_list(self, expr: List[Any]) -> ast.expr:
return ast.List([self.rec(el) for el in expr])
def map_tuple(self, expr: Tuple[Any, ...]) -> ast.expr:
return ast.Tuple([self.rec(el) for el in expr])
def map_if(self, expr: p.If) -> ast.expr:
return ast.IfExp(test=self.rec(expr.condition),
body=self.rec(expr.then),
orelse=self.rec(expr.else_))
def map_nan(self, expr: p.NaN) -> ast.expr:
if isinstance(expr.data_type(float("nan")), float):
return ast.Call(
ast.Name(id="float"),
args=[ast.Constant("nan")],
keywords=[])
else:
# TODO: would need attributes of NumPy
raise NotImplementedError("Non-float nan not implemented")
def map_slice(self, expr: p.Slice) -> ast.expr:
return ast.Slice(*[self.rec(child)
for child in expr.children])
def map_numpy_array(self, expr) -> ast.expr:
raise NotImplementedError
def map_multivector(self, expr) -> ast.expr:
raise NotImplementedError
def map_common_subexpression(self, expr) -> ast.expr:
raise NotImplementedError
def map_substitution(self, expr) -> ast.expr:
raise NotImplementedError
def map_derivative(self, expr) -> ast.expr:
raise NotImplementedError
def map_if_positive(self, expr) -> ast.expr:
raise NotImplementedError
def map_comparison(self, expr: p.Comparison) -> ast.expr:
raise NotImplementedError
def map_polynomial(self, expr) -> ast.expr:
raise NotImplementedError
def map_wildcard(self, expr) -> ast.expr:
raise NotImplementedError
def map_dot_wildcard(self, expr) -> ast.expr:
raise NotImplementedError
def map_star_wildcard(self, expr) -> ast.expr:
raise NotImplementedError
def map_function_symbol(self, expr) -> ast.expr:
raise NotImplementedError
def map_min(self, expr) -> ast.expr:
raise NotImplementedError
def map_max(self, expr) -> ast.expr:
raise NotImplementedError
def to_python_ast(expr) -> ast.expr:
"""
Maps *expr* to :class:`ast.expr`.
"""
return PymbolicToASTMapper()(expr)
def to_evaluatable_python_function(expr: ExpressionT,
fn_name: str
) -> str:
"""
Returns a :class:`str` of the python code with a single function *fn_name*
that takes in the variables in *expr* as keyword-only arguments and returns
the evaluated value of *expr*.
.. testsetup::
>>> from pymbolic import parse
>>> from pymbolic.interop.ast import to_evaluatable_python_function
.. doctest::
>>> expr = parse("S//32 + E%32")
>>> # Skipping doctest as astunparse and ast.unparse have certain subtle
>>> # differences
>>> print(to_evaluatable_python_function(expr, "foo"))) # doctest: +SKIP
def foo(*, E, S):
return S // 32 + E % 32
"""
import sys
from pymbolic.mapper.dependency import CachedDependencyMapper
if sys.version_info < (3, 9):
try:
from astunparse import unparse
except ImportError:
raise RuntimeError("'to_evaluate_python_function' needs"
"astunparse for Py<3.9. Install via `pip"
" install astunparse`")
else:
unparse = ast.unparse
dep_mapper = CachedDependencyMapper(composite_leaves=True)
deps = sorted({dep.name for dep in dep_mapper(expr)})
ast_func = ast.FunctionDef(name=fn_name,
args=ast.arguments(args=[],
posonlyargs=[],
kwonlyargs=[ast.arg(dep, None)
for dep in deps],
kw_defaults=[None]*len(deps),
vararg=None,
kwarg=None,
defaults=[]),
body=[ast.Return(to_python_ast(expr))],
decorator_list=[])
ast_module = ast.Module([ast_func], type_ignores=[])
return unparse(ast.fix_missing_locations(ast_module))
# }}}
# vim: foldmethod=marker
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment