diff --git a/test/test_firedrake_interop.py b/test/test_firedrake_interop.py index 2e50365fe952bdcf0cfb3ae815c0172d22e996a4..57c62723a49ea1de16877be1587a8d482ced448d 100644 --- a/test/test_firedrake_interop.py +++ b/test/test_firedrake_interop.py @@ -300,8 +300,8 @@ def test_bdy_tags(square_or_cube_mesh, bdy_ids, coord_indices, coord_values, # }}} -# TODO : Add function transfer test for ToFiredrakeConnection # TODO : Add idempotency test for ToFiredrakeConnection +# TODO : Add test for ToFiredrakeConnection where group_nr != 0 # {{{ Double check functions are being transported correctly def alternating_sum_fd(spatial_coord): @@ -338,20 +338,22 @@ test_functions = [ @pytest.mark.parametrize("fdrake_f_expr,meshmode_f_eval", test_functions) @pytest.mark.parametrize("only_convert_bdy", (False, True)) -def test_function_transfer(ctx_factory, - fdrake_mesh, fdrake_family, fspace_degree, - fdrake_f_expr, meshmode_f_eval, - only_convert_bdy): +def test_from_fd_transfer(ctx_factory, + fdrake_mesh, fdrake_family, fspace_degree, + fdrake_f_expr, meshmode_f_eval, + only_convert_bdy): """ Make sure creating a function then transporting it is the same (up to resampling error) as creating a function on the transported mesh """ + # make function space and function fdrake_fspace = FunctionSpace(fdrake_mesh, fdrake_family, fspace_degree) spatial_coord = SpatialCoordinate(fdrake_mesh) fdrake_f = Function(fdrake_fspace).interpolate(fdrake_f_expr(spatial_coord)) + # build connection cl_ctx = ctx_factory() if only_convert_bdy: fdrake_connection = FromBdyFiredrakeConnection(cl_ctx, fdrake_fspace, @@ -359,14 +361,56 @@ def test_function_transfer(ctx_factory, else: fdrake_connection = FromFiredrakeConnection(cl_ctx, fdrake_fspace) - transported_f = fdrake_connection.from_firedrake(fdrake_f) + # transport fdrake function + fd2mm_f = fdrake_connection.from_firedrake(fdrake_f) + # build same function in meshmode discr = fdrake_connection.discr with cl.CommandQueue(cl_ctx) as queue: nodes = discr.nodes().get(queue=queue) meshmode_f = meshmode_f_eval(nodes) - np.testing.assert_allclose(transported_f, meshmode_f, atol=CLOSE_ATOL) + # fd -> mm should be same as creating in meshmode + np.testing.assert_allclose(fd2mm_f, meshmode_f, atol=CLOSE_ATOL) + + if not only_convert_bdy: + # now transport mm -> fd + mm2fd_f = \ + fdrake_connection.from_meshmode(meshmode_f, + assert_fdrake_discontinuous=False, + continuity_tolerance=1e-8) + # mm -> fd should be same as creating in firedrake + np.testing.assert_allclose(fdrake_f.dat.data, mm2fd_f.dat.data, + atol=CLOSE_ATOL) + + +@pytest.mark.parametrize("fdrake_f_expr,meshmode_f_eval", test_functions) +def test_to_fd_transfer(ctx_factory, mm_mesh, fspace_degree, + fdrake_f_expr, meshmode_f_eval): + """ + Make sure creating a function then transporting it is the same + (up to resampling error) as creating a function on the transported + mesh + """ + # Make discr and evaluate function in meshmode + cl_ctx = ctx_factory() + factory = InterpolatoryQuadratureSimplexGroupFactory(fspace_degree) + discr = Discretization(cl_ctx, mm_mesh, factory) + + with cl.CommandQueue(cl_ctx) as queue: + nodes = discr.nodes().get(queue=queue) + meshmode_f = meshmode_f_eval(nodes) + + # connect to firedrake and evaluate expr in firedrake + fdrake_connection = ToFiredrakeConnection(discr) + fdrake_fspace = fdrake_connection.firedrake_fspace() + spatial_coord = SpatialCoordinate(fdrake_fspace.mesh()) + fdrake_f = Function(fdrake_fspace).interpolate(fdrake_f_expr(spatial_coord)) + + # transport to firedrake and make sure this is the same + mm2fd_f = fdrake_connection.from_meshmode(meshmode_f) + np.testing.assert_allclose(mm2fd_f.dat.data, fdrake_f.dat.data, + atol=CLOSE_ATOL) # }}}