From c052b74ed67f1269a301ec80624605632de3b2d8 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 13 Mar 2018 02:21:59 -0500 Subject: [PATCH] replace_AssignSolved(): Pass more data to the solver hook. Closes #112 --- examples/imex/imex.py | 6 ++++-- .../implicit_euler/test_implicit_euler.py | 4 ++-- leap/implicit.py | 20 ++++++++++++++++--- test/test_imex.py | 5 +++-- 4 files changed, 26 insertions(+), 9 deletions(-) diff --git a/examples/imex/imex.py b/examples/imex/imex.py index 57fb464..a775234 100644 --- a/examples/imex/imex.py +++ b/examples/imex/imex.py @@ -69,13 +69,15 @@ def solver(f, j, t, u_n, x, c): return f(t=t, y=u) -def solver_hook(solve_expr, guess): +def solver_hook(solve_expr, solve_var, solver_id, guess): from dagrt.expression import match from leap.implicit import make_solver_call pieces = match("unk - rhs(t=t, y=y + sub_y + coeff*unk)", solve_expr, - bound_variable_names=["y"]) + bound_variable_names=["y"], + pre_match={"unk": solve_var}) + return make_solver_call("solver(t, y, sub_y, coeff)", pieces) diff --git a/examples/implicit_euler/test_implicit_euler.py b/examples/implicit_euler/test_implicit_euler.py index 313b016..541a2c8 100755 --- a/examples/implicit_euler/test_implicit_euler.py +++ b/examples/implicit_euler/test_implicit_euler.py @@ -47,10 +47,10 @@ def solver(f, t, h, y, guess): return newton(lambda u: u-y-h*f(t=t, y=u), guess) -def solver_hook(expr, guess): +def solver_hook(expr, var, solver_id, guess): from dagrt.expression import match from leap.implicit import make_solver_call - pieces = match("unk-y-h*f(t=t,y=unk)", expr) + pieces = match("unk-y-h*f(t=t,y=unk)", expr, pre_match={"unk": var}) return make_solver_call("solver(t,h,y,guess)", pieces, guess, guess_name="guess") diff --git a/leap/implicit.py b/leap/implicit.py index f7d891e..610b25a 100644 --- a/leap/implicit.py +++ b/leap/implicit.py @@ -46,7 +46,17 @@ def make_solver_call(template, pieces, guess=None, guess_name=None): def replace_AssignSolved(dag, solver_hooks): """ :arg dag: The :class:`DAGCode` instance - :arg solver_hooks: A map from solver names to expression generators + :arg solver_hooks: A map from solver names to functions that generate solver + calls. + A solver hook should have the signature:: + + def solver_hook(expr, var, id, **kwargs): + + where: + * *expr* is the expression passed to the AssignSolved instruction + * *var* is the name of the unknown + * *id* is the *solver_id* field of the AssignSolved instruction + * any other arguments are passed in *kwargs* """ new_instructions = [] @@ -68,15 +78,19 @@ def replace_AssignSolved(dag, solver_hooks): "returning multiple values.") expression = insn.expressions[0] + solve_variable = insn.solve_variables[0] + solver_id = insn.solver_id other_params = insn.other_params - solver = solver_hooks[insn.solver_id] + solver_hook = solver_hooks[insn.solver_id] + solver_expression = solver_hook(expression, solve_variable, + solver_id, **other_params) new_instructions.append( AssignExpression( assignee=insn.assignees[0], assignee_subscript=(), - expression=solver(expression, **other_params), + expression=solver_expression, id=insn.id, condition=insn.condition, depends_on=insn.depends_on)) diff --git a/test/test_imex.py b/test/test_imex.py index 52676e1..19a0a0e 100755 --- a/test/test_imex.py +++ b/test/test_imex.py @@ -41,11 +41,12 @@ def solver(f, t, sub_y, coeff, guess): return root(lambda unk: unk - f(t=t, y=sub_y + coeff*unk), guess).x -def solver_hook(solve_expr, guess): +def solver_hook(solve_expr, solve_var, solver_id, guess): from dagrt.expression import match from leap.implicit import make_solver_call - pieces = match("unk - rhs(t=t, y=sub_y + coeff*unk)", solve_expr) + pieces = match("unk - rhs(t=t, y=sub_y + coeff*unk)", solve_expr, + pre_match={"unk": solve_var}) return make_solver_call("solver(t, sub_y, coeff, guess)", pieces, guess, guess_name="guess") -- GitLab