From 9913b3311b38f33aff671a2cede22a28a165fc8c Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 13 Mar 2018 18:55:38 -0500 Subject: [PATCH] replace_AssignSolved(): Allow passing a callable for solver_hooks. --- leap/implicit.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/leap/implicit.py b/leap/implicit.py index 26d18f6..1ebfd79 100644 --- a/leap/implicit.py +++ b/leap/implicit.py @@ -27,6 +27,9 @@ THE SOFTWARE. """ +import six + + def make_solver_call(template, pieces, guess=None, guess_name=None): """ :arg template: A template for a solver call @@ -46,8 +49,9 @@ def make_solver_call(template, pieces, guess=None, guess_name=None): def replace_AssignSolved(dag, solver_hooks): """ :arg dag: The :class:`DAGCode` instance - :arg solver_hooks: A map from solver names to functions that generate solver - calls. + :arg solver_hooks: Either a callable, or a map from solver names to + functions that generate solver calls. + A solver hook should have the signature:: def solver_hook(expr, var, id, **kwargs): @@ -56,9 +60,14 @@ def replace_AssignSolved(dag, solver_hooks): * *expr* is the expression passed to the AssignSolved instruction * *var* is the name of the unknown * *id* is the *solver_id* field of the AssignSolved instruction - * any other arguments are passed in *kwargs* + * any other arguments are passed in *kwargs*. """ + if six.callable(solver_hooks): + hook = solver_hooks + from collections import defaultdict + solver_hooks = defaultdict(lambda: hook) + new_statements = [] from dagrt.language import AssignExpression, AssignSolved -- GitLab