diff --git a/dagrt/expression.py b/dagrt/expression.py index e47a94f6b201f243d9379ed9f626ba0ea72e1340..9c3b83e7e6e8943863ff16bb5c7323ef321eb252 100644 --- a/dagrt/expression.py +++ b/dagrt/expression.py @@ -414,8 +414,10 @@ class _ExtendedParser(Parser): def parse(expr): - """Return a pymbolic expression constructed from the string. Values - between backticks ("`") are parsed as variable names. + """Return a pymbolic expression constructed from the string. + + Values between backticks ("`") are parsed as variable names. + Tagged identifiers ("f") are also parsed as variable names. """ from pymbolic import var @@ -433,4 +435,20 @@ def parse(expr): return substitutor(parser(expr)) +def substitute(expression, variable_assignments={}, **kwargs): + """Perform variable substitution. + + :arg expression: A string or :mod:`pymbolic` expression. + If a string, it will be parsed with :func:`parse`. + :arg variable_assignments: Mapping from variable names to expressions + :arg kwargs: Extra arguments passed to to :func:`pymbolic.substitute` + """ + from pymbolic import substitute as substitute_pymbolic + + if isinstance(expression, str): + expression = parse(expression) + + return substitute_pymbolic(expression, variable_assignments, **kwargs) + + # vim: foldmethod=marker diff --git a/test/test_expressions.py b/test/test_expressions.py index 84b04dd7af795911033c56d0011ba2a31d2a901e..97b9e8101aaff137f928aaa4eefb0b0664e74e72 100755 --- a/test/test_expressions.py +++ b/test/test_expressions.py @@ -133,6 +133,13 @@ def test_get_variables_with_function_symbols(): frozenset(['f', 'x']) +def test_substitute(): + f, a = declare("f", "a") + + from dagrt.expression import substitute + assert substitute("f(y)", {"y": a}) == f(a) + + # {{{ parser def test_parser():