diff --git a/examples/imex/imex.py b/examples/imex/imex.py index 57fb464f9b5740a2f2236a64ffd93b5d8e0c587e..a7752347f859826ccc6cf4c5dd7e9e5eeee0d976 100644 --- a/examples/imex/imex.py +++ b/examples/imex/imex.py @@ -69,13 +69,15 @@ def solver(f, j, t, u_n, x, c): return f(t=t, y=u) -def solver_hook(solve_expr, guess): +def solver_hook(solve_expr, solve_var, solver_id, guess): from dagrt.expression import match from leap.implicit import make_solver_call pieces = match("unk - rhs(t=t, y=y + sub_y + coeff*unk)", solve_expr, - bound_variable_names=["y"]) + bound_variable_names=["y"], + pre_match={"unk": solve_var}) + return make_solver_call("solver(t, y, sub_y, coeff)", pieces) diff --git a/examples/implicit_euler/test_implicit_euler.py b/examples/implicit_euler/test_implicit_euler.py index 8066452bee23f7ce87ab2a1062ee36e2fed5b79c..203c6ce22452b65d3ff6176ab90813c801b728ac 100755 --- a/examples/implicit_euler/test_implicit_euler.py +++ b/examples/implicit_euler/test_implicit_euler.py @@ -47,10 +47,10 @@ def solver(f, t, h, y, guess): return newton(lambda u: u-y-h*f(t=t, y=u), guess) -def solver_hook(expr, guess): +def solver_hook(expr, var, solver_id, guess): from dagrt.expression import match from leap.implicit import make_solver_call - pieces = match("unk-y-h*f(t=t,y=unk)", expr) + pieces = match("unk-y-h*f(t=t,y=unk)", expr, pre_match={"unk": var}) return make_solver_call("solver(t,h,y,guess)", pieces, guess, guess_name="guess") diff --git a/leap/implicit.py b/leap/implicit.py index 0527066f936d9b491e7dc7e89b1e024fac11fa53..26d18f634d3e8ee382da10269332dea507a40da4 100644 --- a/leap/implicit.py +++ b/leap/implicit.py @@ -46,7 +46,17 @@ def make_solver_call(template, pieces, guess=None, guess_name=None): def replace_AssignSolved(dag, solver_hooks): """ :arg dag: The :class:`DAGCode` instance - :arg solver_hooks: A map from solver names to expression generators + :arg solver_hooks: A map from solver names to functions that generate solver + calls. + A solver hook should have the signature:: + + def solver_hook(expr, var, id, **kwargs): + + where: + * *expr* is the expression passed to the AssignSolved instruction + * *var* is the name of the unknown + * *id* is the *solver_id* field of the AssignSolved instruction + * any other arguments are passed in *kwargs* """ new_statements = [] @@ -68,15 +78,19 @@ def replace_AssignSolved(dag, solver_hooks): "returning multiple values.") expression = stmt.expressions[0] + solve_variable = stmt.solve_variables[0] + solver_id = stmt.solver_id other_params = stmt.other_params - solver = solver_hooks[stmt.solver_id] + solver_hook = solver_hooks[stmt.solver_id] + solver_expression = solver_hook(expression, solve_variable, + solver_id, **other_params) new_statements.append( AssignExpression( assignee=stmt.assignees[0], assignee_subscript=(), - expression=solver(expression, **other_params), + expression=solver_expression, id=stmt.id, condition=stmt.condition, depends_on=stmt.depends_on)) diff --git a/test/test_imex.py b/test/test_imex.py index 52676e1d6c1feb061715a1f5a59d0a8ef9476bf1..19a0a0e436d2dd6f08df8be5ecc411898e25ad22 100755 --- a/test/test_imex.py +++ b/test/test_imex.py @@ -41,11 +41,12 @@ def solver(f, t, sub_y, coeff, guess): return root(lambda unk: unk - f(t=t, y=sub_y + coeff*unk), guess).x -def solver_hook(solve_expr, guess): +def solver_hook(solve_expr, solve_var, solver_id, guess): from dagrt.expression import match from leap.implicit import make_solver_call - pieces = match("unk - rhs(t=t, y=sub_y + coeff*unk)", solve_expr) + pieces = match("unk - rhs(t=t, y=sub_y + coeff*unk)", solve_expr, + pre_match={"unk": solve_var}) return make_solver_call("solver(t, sub_y, coeff, guess)", pieces, guess, guess_name="guess")