diff --git a/dagrt/expression.py b/dagrt/expression.py index 63f1b7c5d1d0e37df6e21354c2c035ea7aa82422..e47a94f6b201f243d9379ed9f626ba0ea72e1340 100644 --- a/dagrt/expression.py +++ b/dagrt/expression.py @@ -319,37 +319,64 @@ class _ExtendedUnifier(UnidirectionalUnifier): def match(template, expression, free_variable_names=None, - bound_variable_names=None): - """ - Attempt to match the free variables found in `template` to terms in + bound_variable_names=None, pre_match=None): + """Attempt to match the free variables found in `template` to terms in `expression`, modulo associativity and commutativity. This implements a one-way unification algorithm, matching free variables in `template` to subexpressions of `expression`. + If `free_variable_names` is *None*, then all variables except those in + `bound_variable_names` are treated as free. + + Matches that are already known to hold can be specified in `pre_match`, a + map from variable names to subexpressions (or strings representing + subexpressions). + Return a map from variable names in `free_variable_names` to expressions. """ if isinstance(template, str): template = parse(template) + if isinstance(expression, str): expression = parse(expression) + if bound_variable_names is None: bound_variable_names = set() + if free_variable_names is None: from dagrt.utils import get_variables free_variable_names = get_variables( template, include_function_symbols=True) free_variable_names -= set(bound_variable_names) + + urecs = None + if pre_match is not None: + eqns = [] + for name, expr in six.iteritems(pre_match): + if name not in free_variable_names: + raise ValueError( + "'%s' was given in 'pre_match' but is " + "not a candidate for matching" % name) + if isinstance(expr, str): + expr = parse(expr) + eqns.append((Variable(name), expr)) + from pymbolic.mapper.unifier import UnificationRecord + urecs = [UnificationRecord(eqns)] + unifier = _ExtendedUnifier(free_variable_names) - records = unifier(template, expression) + records = unifier(template, expression, urecs) + if len(records) > 1: from warnings import warn warn("Matching\n\"{expr}\"\nto\n\"{template}\"\n" "is ambiguous - using first match".format( expr=expression, template=template)) + if not records: raise ValueError("Cannot unify expressions.") + return dict((key.name, val) for key, val in records[0].equations) diff --git a/test/test_expressions.py b/test/test_expressions.py index 92d04a9f2830f58e3566cc00f851b4a6f3221a2e..84b04dd7af795911033c56d0011ba2a31d2a901e 100755 --- a/test/test_expressions.py +++ b/test/test_expressions.py @@ -24,6 +24,9 @@ THE SOFTWARE. """ +import pytest + + def test_collapse_constants(): from pymbolic import var f = var("f") @@ -96,6 +99,22 @@ def test_match_modulo_identity(): assert subst["b"] == 0 +def test_match_with_pre_match(): + a, b, c, d = declare("a", "b", "c", "d") + from dagrt.expression import match + subst = match(a + b, c + d, ["a", "b"], pre_match={"a": "c"}) + + assert subst["a"] == c + assert subst["b"] == d + + +def test_match_with_pre_match_invalid_arg(): + a, b, c, d = declare("a", "b", "c", "d") + from dagrt.expression import match + with pytest.raises(ValueError): + match(a + b, c + d, ["a"], pre_match={"b": "c"}) + + def test_get_variables(): from pymbolic import var f = var('f')