diff --git a/examples/imex/imex.py b/examples/imex/imex.py index a7752347f859826ccc6cf4c5dd7e9e5eeee0d976..16ca2303d4d3133dbadf8dd5569964dbb1883517 100644 --- a/examples/imex/imex.py +++ b/examples/imex/imex.py @@ -70,15 +70,16 @@ def solver(f, j, t, u_n, x, c): def solver_hook(solve_expr, solve_var, solver_id, guess): - from dagrt.expression import match - from leap.implicit import make_solver_call + from dagrt.expression import match, substitute pieces = match("unk - rhs(t=t, y=y + sub_y + coeff*unk)", solve_expr, bound_variable_names=["y"], pre_match={"unk": solve_var}) - return make_solver_call("solver(t, y, sub_y, coeff)", pieces) + pieces["guess"] = guess + + return substitute("solver(t, y, sub_y, coeff)", pieces) def run(): diff --git a/examples/implicit_euler/test_implicit_euler.py b/examples/implicit_euler/test_implicit_euler.py index 203c6ce22452b65d3ff6176ab90813c801b728ac..ad7da16cab126ddad76d4a2130ae9decd71741ef 100755 --- a/examples/implicit_euler/test_implicit_euler.py +++ b/examples/implicit_euler/test_implicit_euler.py @@ -48,11 +48,10 @@ def solver(f, t, h, y, guess): def solver_hook(expr, var, solver_id, guess): - from dagrt.expression import match - from leap.implicit import make_solver_call + from dagrt.expression import match, substitute pieces = match("unk-y-h*f(t=t,y=unk)", expr, pre_match={"unk": var}) - return make_solver_call("solver(t,h,y,guess)", - pieces, guess, guess_name="guess") + pieces["guess"] = guess + return substitute("solver(t,h,y,guess)", pieces) @pytest.mark.parametrize("python_method_impl", diff --git a/leap/implicit.py b/leap/implicit.py index 1ebfd7966eaf09f20a23a836dc222ff821dcf08d..02e98f1d727b5f98f6cc0814e0fce153381c4ee2 100644 --- a/leap/implicit.py +++ b/leap/implicit.py @@ -26,26 +26,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - import six -def make_solver_call(template, pieces, guess=None, guess_name=None): - """ - :arg template: A template for a solver call - :arg pieces: A dictionary mapping variable names to subexpressions, to - substitute into the template - :arg guess: The expression for the initial guess - :arg guess_name: The variable name for the initial guess - """ - if isinstance(template, str): - from dagrt.expression import parse - template = parse(template) - from pymbolic import substitute - pieces.update({guess_name: guess}) - return substitute(template, pieces) - - def replace_AssignSolved(dag, solver_hooks): """ :arg dag: The :class:`DAGCode` instance diff --git a/test/test_imex.py b/test/test_imex.py index 19a0a0e436d2dd6f08df8be5ecc411898e25ad22..f8533516810bf5dfcf85c11d7546d42c11917ec3 100755 --- a/test/test_imex.py +++ b/test/test_imex.py @@ -42,14 +42,12 @@ def solver(f, t, sub_y, coeff, guess): def solver_hook(solve_expr, solve_var, solver_id, guess): - from dagrt.expression import match - from leap.implicit import make_solver_call + from dagrt.expression import match, substitute pieces = match("unk - rhs(t=t, y=sub_y + coeff*unk)", solve_expr, pre_match={"unk": solve_var}) - return make_solver_call("solver(t, sub_y, coeff, guess)", - pieces, - guess, guess_name="guess") + pieces["guess"] = guess + return substitute("solver(t, sub_y, coeff, guess)", pieces) @pytest.mark.parametrize("problem, method, expected_order", [