From 0543fceecd7f44ce6d7323f0f5259dba5e774e9d Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 13 Mar 2018 01:53:53 -0500 Subject: [PATCH 1/3] Implement a "pre_match" argument which allows pre-specifying free variable matches in dagrt.expression.match(). --- dagrt/expression.py | 31 +++++++++++++++++++++++++++---- test/test_expressions.py | 9 +++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/dagrt/expression.py b/dagrt/expression.py index 63f1b7c..8ac7c04 100644 --- a/dagrt/expression.py +++ b/dagrt/expression.py @@ -319,37 +319,60 @@ class _ExtendedUnifier(UnidirectionalUnifier): def match(template, expression, free_variable_names=None, - bound_variable_names=None): - """ - Attempt to match the free variables found in `template` to terms in + bound_variable_names=None, pre_match=None): + """Attempt to match the free variables found in `template` to terms in `expression`, modulo associativity and commutativity. This implements a one-way unification algorithm, matching free variables in `template` to subexpressions of `expression`. + If `free_variable_names` is *None*, then all variables except those in + `bound_variable_names` are treated as free. + + Matches that are already known to hold can be specified in `pre_match`, a + map from variable names to subexpression (or strings representing + subexpressions). + Return a map from variable names in `free_variable_names` to expressions. """ if isinstance(template, str): template = parse(template) + if isinstance(expression, str): expression = parse(expression) + if bound_variable_names is None: bound_variable_names = set() + if free_variable_names is None: from dagrt.utils import get_variables free_variable_names = get_variables( template, include_function_symbols=True) free_variable_names -= set(bound_variable_names) + + urecs = None + if pre_match is not None: + eqns = [] + for name, expr in six.iteritems(pre_match): + if isinstance(expr, str): + expr = parse(expr) + eqns.append((Variable(name), expr)) + from pymbolic.mapper.unifier import UnificationRecord + urecs = [UnificationRecord(eqns)] + unifier = _ExtendedUnifier(free_variable_names) - records = unifier(template, expression) + records = unifier(template, expression, urecs) + if len(records) > 1: from warnings import warn warn("Matching\n\"{expr}\"\nto\n\"{template}\"\n" "is ambiguous - using first match".format( expr=expression, template=template)) + if not records: raise ValueError("Cannot unify expressions.") + return dict((key.name, val) for key, val in records[0].equations) diff --git a/test/test_expressions.py b/test/test_expressions.py index 92d04a9..395e9e4 100755 --- a/test/test_expressions.py +++ b/test/test_expressions.py @@ -96,6 +96,15 @@ def test_match_modulo_identity(): assert subst["b"] == 0 +def test_match_with_pre_match(): + a, b, c, d = declare("a", "b", "c", "d") + from dagrt.expression import match + subst = match(a + b, c + d, ["a", "b"], pre_match={"a": "c"}) + + assert subst["a"] == c + assert subst["b"] == d + + def test_get_variables(): from pymbolic import var f = var('f') -- GitLab From 443cbe95b5b34e2070cbeae78e1a8789487a7541 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 13 Mar 2018 01:58:45 -0500 Subject: [PATCH 2/3] Error out if pre_match has names not in the set of free variables --- dagrt/expression.py | 4 ++++ test/test_expressions.py | 10 ++++++++++ 2 files changed, 14 insertions(+) diff --git a/dagrt/expression.py b/dagrt/expression.py index 8ac7c04..2294848 100644 --- a/dagrt/expression.py +++ b/dagrt/expression.py @@ -355,6 +355,10 @@ def match(template, expression, free_variable_names=None, if pre_match is not None: eqns = [] for name, expr in six.iteritems(pre_match): + if name not in free_variable_names: + raise ValueError( + "'%s' was given in 'pre_match' but is " + "not a candidate for matching" % name) if isinstance(expr, str): expr = parse(expr) eqns.append((Variable(name), expr)) diff --git a/test/test_expressions.py b/test/test_expressions.py index 395e9e4..84b04dd 100755 --- a/test/test_expressions.py +++ b/test/test_expressions.py @@ -24,6 +24,9 @@ THE SOFTWARE. """ +import pytest + + def test_collapse_constants(): from pymbolic import var f = var("f") @@ -105,6 +108,13 @@ def test_match_with_pre_match(): assert subst["b"] == d +def test_match_with_pre_match_invalid_arg(): + a, b, c, d = declare("a", "b", "c", "d") + from dagrt.expression import match + with pytest.raises(ValueError): + match(a + b, c + d, ["a"], pre_match={"b": "c"}) + + def test_get_variables(): from pymbolic import var f = var('f') -- GitLab From 9b47d48703f482ec566f797591034278860f502a Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 13 Mar 2018 02:03:12 -0500 Subject: [PATCH 3/3] Typo fix --- dagrt/expression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagrt/expression.py b/dagrt/expression.py index 2294848..e47a94f 100644 --- a/dagrt/expression.py +++ b/dagrt/expression.py @@ -330,7 +330,7 @@ def match(template, expression, free_variable_names=None, `bound_variable_names` are treated as free. Matches that are already known to hold can be specified in `pre_match`, a - map from variable names to subexpression (or strings representing + map from variable names to subexpressions (or strings representing subexpressions). Return a map from variable names in `free_variable_names` to -- GitLab