diff --git a/meshmode/discretization/connection/__init__.py b/meshmode/discretization/connection/__init__.py index da982c9d94868d2f9f1f5ef386f2fcaeb7b33342..85e41d89f832f0e410e0c3a208ae708724f43f82 100644 --- a/meshmode/discretization/connection/__init__.py +++ b/meshmode/discretization/connection/__init__.py @@ -218,6 +218,7 @@ class ChainedDiscretizationConnection(DiscretizationConnection): else: from_discr = connections[0].from_discr is_surjective = all(cnx.is_surjective for cnx in connections) + to_discr = connections[-1].to_discr else: if from_discr is None: raise ValueError("connections may not be empty if from_discr " diff --git a/test/test_meshmode.py b/test/test_meshmode.py index 2cb0470a0656d5d46242efdea2e23c9df2d31f87..958130acea72b706722f1cee26045931f61665ec 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -1036,6 +1036,64 @@ def test_quad_multi_element(): plt.show() +# {{{ ChainedDiscretizationConnection + +def test_ChainedDiscretizationConnection(ctx_getter): # noqa + mesh_order = 5 + order = 5 + npanels = 10 + group_factory = InterpolatoryQuadratureSimplexGroupFactory + + def refine_flags(mesh): + return np.ones(mesh.nelements) + + cl_ctx = ctx_getter() + queue = cl.CommandQueue(cl_ctx) + + from functools import partial + from meshmode.discretization import Discretization + from meshmode.discretization.connection import make_refinement_connection + from meshmode.mesh.generation import make_curve_mesh, ellipse + + mesh = make_curve_mesh( + partial(ellipse, 1), np.linspace(0, 1, npanels + 1), + order=mesh_order) + + discr = Discretization(cl_ctx, mesh, group_factory(order)) + + connections = [] + + def refine_discr(discr): + mesh = discr.mesh + from meshmode.mesh.refinement import Refiner + refiner = Refiner(mesh) + flags = refine_flags(mesh) + refiner.refine(flags) + connections.append( + make_refinement_connection(refiner, discr, group_factory(order))) + return connections[-1].to_discr + + discr = refine_discr(discr) + refine_discr(discr) + + from meshmode.discretization.connection import ( + ChainedDiscretizationConnection) + + chained_conn = ChainedDiscretizationConnection(connections) + + def f(x): + from six.moves import reduce + return 0.1 * reduce(lambda x, y: x * cl.clmath.sin(5 * y), x) + + x = connections[0].from_discr.nodes().with_queue(queue) + + assert np.allclose( + chained_conn(queue, f(x)).get(queue), + connections[1](queue, connections[0](queue, f(x))).get(queue)) + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: