diff --git a/meshmode/discretization/connection/opposite_face.py b/meshmode/discretization/connection/opposite_face.py index bfed445c5a5efac8167b5b162e9bc9269576fbab..925bfe0d1fea79aeb7e5efb82834c94b3a229c25 100644 --- a/meshmode/discretization/connection/opposite_face.py +++ b/meshmode/discretization/connection/opposite_face.py @@ -393,6 +393,8 @@ def make_opposite_face_connection(volume_to_bdry_conn): # }}} +# {{{ partition_connection + def _make_cross_partition_batch(queue, vol_to_bdry_conns, adj, i_tgt_part, i_tgt_grp, i_tgt_elem, i_tgt_face): """ @@ -555,7 +557,7 @@ def _make_cross_partition_batch(queue, vol_to_bdry_conns, adj, to_element_face=None) -def make_partition_connection(vol_to_bdry_conns): +def make_partition_connection(vol_to_bdry_conns, adj_parts): """ Given a list of boundary restriction connections *volume_to_bdry_conn*, return a :class:`DirectDiscretizationConnection` that performs data @@ -572,27 +574,24 @@ def make_partition_connection(vol_to_bdry_conns): from meshmode.discretization.connection import ( DirectDiscretizationConnection, DiscretizationConnectionElementGroup) - # My intuition tells me that this should not live inside a for loop. - # However, I need to grab a cl_context. I'll assume that each context from - # each partition is the same and I'll use the first one. - cl_context = vol_to_bdry_conns[0].from_discr.cl_context - with cl.CommandQueue(cl_context) as queue: - # Create a list of batches. Each batch contains interpolation - # data from one partition to another. - for i_tgt_part, tgt_vol_conn in enumerate(vol_to_bdry_conns): + # Create a list of batches. Each batch contains interpolation + # data from one partition to another. + for i_tgt_part, tgt_vol_conn in enumerate(vol_to_bdry_conns): + + # Is this ok in a for loop? + cl_context = tgt_vol_conn.from_discr.cl_context + with cl.CommandQueue(cl_context) as queue: + bdry_discr = tgt_vol_conn.to_discr - tgt_mesh = tgt_vol_conn.to_discr.mesh + tgt_mesh = bdry_discr.mesh ngroups = len(tgt_mesh.groups) part_batches = [[] for _ in range(ngroups)] - for tgt_group_num, adj in enumerate(tgt_mesh.interpart_adj_groups): + # Hack, I need to get InterPartitionAdj so I'll receive it directly + # as an argument. + for tgt_group_num, adj in enumerate(adj_parts[i_tgt_part]): for idx, tgt_elem in enumerate(adj.elements): tgt_face = adj.element_faces[idx] - # We need to create a batch using the - # neighboring face, element, and group - # I'm not sure how I would do this. - # My guess is that it would look - # something like _make_cross_face_batches part_batches[tgt_group_num].append( _make_cross_partition_batch( queue, @@ -614,4 +613,7 @@ def make_partition_connection(vol_to_bdry_conns): return disc_conns +# }}} + + # vim: foldmethod=marker diff --git a/test/test_meshmode.py b/test/test_meshmode.py index c29f013ad5deba8c74038827687f8ada8a1e2a30..9fdf1757196d2c560c5283a8d2322fabd2cd83de 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -48,6 +48,8 @@ import logging logger = logging.getLogger(__name__) +# {{{ partition_interpolation + def test_partition_interpolation(ctx_getter): cl_ctx = ctx_getter() order = 4 @@ -56,7 +58,11 @@ def test_partition_interpolation(ctx_getter): dim = 2 num_parts = 7 from meshmode.mesh.generation import generate_warped_rect_mesh - mesh = generate_warped_rect_mesh(dim, order=order, n=n) + mesh1 = generate_warped_rect_mesh(dim, order=order, n=n) + mesh2 = generate_warped_rect_mesh(dim, order=order, n=n) + + from meshmode.mesh.processing import merge_disjoint_meshes + mesh = merge_disjoint_meshes([mesh1, mesh2]) adjacency_list = np.zeros((mesh.nelements,), dtype=set) for elem in range(mesh.nelements): @@ -73,6 +79,9 @@ def test_partition_interpolation(ctx_getter): part_meshes = [ partition_mesh(mesh, part_per_element, i)[0] for i in range(num_parts)] + # Hack, I get InterPartitionAdj here instead of from vol_discrs. + adj_parts = [part_meshes[i].interpart_adj_groups for i in range(num_parts)] + from meshmode.discretization import Discretization vol_discrs = [Discretization(cl_ctx, part_meshes[i], group_factory) for i in range(num_parts)] @@ -82,12 +91,14 @@ def test_partition_interpolation(ctx_getter): FRESTR_INTERIOR_FACES) for i in range(num_parts)] from meshmode.discretization.connection import make_partition_connection - connections = make_partition_connection(bdry_connections) + connections = make_partition_connection(bdry_connections, adj_parts) from meshmode.discretization.connection import check_connection for conn in connections: check_connection(conn) +# }}} + # {{{ partition_mesh