diff --git a/meshmode/discretization/connection/chained.py b/meshmode/discretization/connection/chained.py index 13a70e004e7f8e6b2535266dfab21615e296c0af..13a098d3a51afa34a6a2b9f11f1227560097ddb5 100644 --- a/meshmode/discretization/connection/chained.py +++ b/meshmode/discretization/connection/chained.py @@ -158,10 +158,10 @@ def flatten_chained_connection(queue, connection): DiscretizationConnectionElementGroup, make_same_mesh_connection) - if isinstance(connection, DirectDiscretizationConnection): + if not hasattr(connection, 'connections'): return connection - if not hasattr(connection, 'connections') or not connection.connections: + if not connection.connections: return make_same_mesh_connection(connection.to_discr, connection.from_discr) diff --git a/test/test_chained.py b/test/test_chained.py index 450a76ecb09993261db82c2dfe3e2513f5d2b76a..112e1c6e27405eea0d6d9573fb74b3405aa77dbb 100644 --- a/test/test_chained.py +++ b/test/test_chained.py @@ -43,7 +43,7 @@ def create_discretization(queue, ndim, discr_order=5): ctx = queue.context - # construct base mesh + # construct mesh if ndim == 2: from functools import partial from meshmode.mesh.generation import make_curve_mesh, ellipse @@ -56,7 +56,7 @@ def create_discretization(queue, ndim, else: raise ValueError("Unsupported dimension: {}".format(ndim)) - # create base discretization + # create discretization from meshmode.discretization import Discretization from meshmode.discretization.poly_element import \ InterpolatoryQuadratureSimplexGroupFactory @@ -98,6 +98,7 @@ def create_face_connection(queue, discr): return connection +@pytest.mark.skip(reason='implementation detail') @pytest.mark.parametrize("ndim", [2, 3]) def test_chained_batch_table(ctx_factory, ndim, visualize=False): from meshmode.discretization.connection.chained import \ @@ -129,6 +130,7 @@ def test_chained_batch_table(ctx_factory, ndim, visualize=False): assert k == batch.from_element_indices[i] +@pytest.mark.skip(reason='implementation detail') @pytest.mark.parametrize("ndim", [2, 3]) def test_chained_new_group_table(ctx_factory, ndim, visualize=False): from meshmode.discretization.connection.chained import \ @@ -178,6 +180,38 @@ def test_chained_new_group_table(ctx_factory, ndim, visualize=False): pt.clf() +@pytest.mark.parametrize("ndim", [2, 3]) +def test_chained_connection(ctx_factory, ndim, visualize=False): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + discr = create_discretization(queue, ndim, + nelements=10, + mesh_order=5, + discr_order=5) + connections = [] + conn = create_refined_connection(queue, discr, threshold=np.inf) + connections.append(conn) + conn = create_refined_connection(queue, conn.to_discr, threshold=np.inf) + connections.append(conn) + + from meshmode.discretization.connection import \ + ChainedDiscretizationConnection + chained = ChainedDiscretizationConnection(connections) + + def f(x): + from six.moves import reduce + return 0.1 * reduce(lambda x, y: x * cl.clmath.sin(5 * y), x) + + x = connections[0].from_discr.nodes().with_queue(queue) + fx = f(x) + f1 = chained(queue, fx).get(queue) + f2 = connections[1](queue, connections[0](queue, fx)).get(queue) + + assert np.allclose(f1, f2) + + +@pytest.mark.slowtest @pytest.mark.parametrize("ndim", [2, 3]) def test_chained_full_resample_matrix(ctx_factory, ndim, visualize=False): from meshmode.discretization.connection.chained import \ @@ -195,7 +229,8 @@ def test_chained_full_resample_matrix(ctx_factory, ndim, visualize=False): conn = create_refined_connection(queue, conn.to_discr) connections.append(conn) - from meshmode.discretization.connection import ChainedDiscretizationConnection + from meshmode.discretization.connection import \ + ChainedDiscretizationConnection chained = ChainedDiscretizationConnection(connections) def f(x): @@ -246,7 +281,8 @@ def test_chained_to_direct(ctx_factory, ndim, chain_type, visualize=False): else: raise ValueError('unknown test case') - from meshmode.discretization.connection import ChainedDiscretizationConnection + from meshmode.discretization.connection import \ + ChainedDiscretizationConnection chained = ChainedDiscretizationConnection(connections) direct = flatten_chained_connection(queue, chained) diff --git a/test/test_meshmode.py b/test/test_meshmode.py index 7173fc29d359b9a6e7ef65f889f819e3290a455a..29dc2ed1075aa31621e9d4d5ac1578c9732f8bd8 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -1039,65 +1039,6 @@ def test_quad_multi_element(): plt.show() -# {{{ ChainedDiscretizationConnection - -def test_ChainedDiscretizationConnection(ctx_getter): # noqa - mesh_order = 5 - order = 5 - npanels = 10 - group_factory = InterpolatoryQuadratureSimplexGroupFactory - - def refine_flags(mesh): - return np.ones(mesh.nelements) - - cl_ctx = ctx_getter() - queue = cl.CommandQueue(cl_ctx) - - from functools import partial - from meshmode.discretization import Discretization - from meshmode.discretization.connection import make_refinement_connection - from meshmode.mesh.generation import make_curve_mesh, ellipse - - mesh = make_curve_mesh( - partial(ellipse, 1), np.linspace(0, 1, npanels + 1), - order=mesh_order) - - discr = Discretization(cl_ctx, mesh, group_factory(order)) - - connections = [] - - from meshmode.mesh.refinement import Refiner - refiner = Refiner(mesh) - - def refine_discr(discr): - mesh = discr.mesh - flags = refine_flags(mesh) - refiner.refine(flags) - connections.append( - make_refinement_connection(refiner, discr, group_factory(order))) - return connections[-1].to_discr - - discr = refine_discr(discr) - refine_discr(discr) - - from meshmode.discretization.connection import ( - ChainedDiscretizationConnection) - - chained_conn = ChainedDiscretizationConnection(connections) - - def f(x): - from six.moves import reduce - return 0.1 * reduce(lambda x, y: x * cl.clmath.sin(5 * y), x) - - x = connections[0].from_discr.nodes().with_queue(queue) - - assert np.allclose( - chained_conn(queue, f(x)).get(queue), - connections[1](queue, connections[0](queue, f(x))).get(queue)) - -# }}} - - if __name__ == "__main__": import sys if len(sys.argv) > 1: