From 06d600340189a4ebfbb1edbd8da40f531b43138c Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 13 Mar 2018 19:26:26 -0500 Subject: [PATCH] Add dagrt.expression.substitute(). This acts exactly like pymbolic's substitute(), but it takes a string argument which can be fed through dagrt's extended parser. Related: leap#121 --- dagrt/expression.py | 22 ++++++++++++++++++++-- test/test_expressions.py | 7 +++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/dagrt/expression.py b/dagrt/expression.py index e47a94f..9c3b83e 100644 --- a/dagrt/expression.py +++ b/dagrt/expression.py @@ -414,8 +414,10 @@ class _ExtendedParser(Parser): def parse(expr): - """Return a pymbolic expression constructed from the string. Values - between backticks ("`") are parsed as variable names. + """Return a pymbolic expression constructed from the string. + + Values between backticks ("`") are parsed as variable names. + Tagged identifiers ("f") are also parsed as variable names. """ from pymbolic import var @@ -433,4 +435,20 @@ def parse(expr): return substitutor(parser(expr)) +def substitute(expression, variable_assignments={}, **kwargs): + """Perform variable substitution. + + :arg expression: A string or :mod:`pymbolic` expression. + If a string, it will be parsed with :func:`parse`. + :arg variable_assignments: Mapping from variable names to expressions + :arg kwargs: Extra arguments passed to to :func:`pymbolic.substitute` + """ + from pymbolic import substitute as substitute_pymbolic + + if isinstance(expression, str): + expression = parse(expression) + + return substitute_pymbolic(expression, variable_assignments, **kwargs) + + # vim: foldmethod=marker diff --git a/test/test_expressions.py b/test/test_expressions.py index 84b04dd..97b9e81 100755 --- a/test/test_expressions.py +++ b/test/test_expressions.py @@ -133,6 +133,13 @@ def test_get_variables_with_function_symbols(): frozenset(['f', 'x']) +def test_substitute(): + f, a = declare("f", "a") + + from dagrt.expression import substitute + assert substitute("f(y)", {"y": a}) == f(a) + + # {{{ parser def test_parser(): -- GitLab