From 0a559f3c20efd325e916076f7475f5371ec91acb Mon Sep 17 00:00:00 2001
From: ellis <eshoag2@illinois.edu>
Date: Fri, 24 Mar 2017 10:30:23 -0500
Subject: [PATCH] Temporarily pass interpart_adj to make_partition_connection

---
 .../connection/opposite_face.py               | 34 ++++++++++---------
 test/test_meshmode.py                         | 15 ++++++--
 2 files changed, 31 insertions(+), 18 deletions(-)

diff --git a/meshmode/discretization/connection/opposite_face.py b/meshmode/discretization/connection/opposite_face.py
index bfed445c..925bfe0d 100644
--- a/meshmode/discretization/connection/opposite_face.py
+++ b/meshmode/discretization/connection/opposite_face.py
@@ -393,6 +393,8 @@ def make_opposite_face_connection(volume_to_bdry_conn):
 # }}}
 
 
+# {{{ partition_connection
+
 def _make_cross_partition_batch(queue, vol_to_bdry_conns, adj,
                             i_tgt_part, i_tgt_grp, i_tgt_elem, i_tgt_face):
     """
@@ -555,7 +557,7 @@ def _make_cross_partition_batch(queue, vol_to_bdry_conns, adj,
                 to_element_face=None)
 
 
-def make_partition_connection(vol_to_bdry_conns):
+def make_partition_connection(vol_to_bdry_conns, adj_parts):
     """
     Given a list of boundary restriction connections *volume_to_bdry_conn*,
     return a :class:`DirectDiscretizationConnection` that performs data
@@ -572,27 +574,24 @@ def make_partition_connection(vol_to_bdry_conns):
     from meshmode.discretization.connection import (
             DirectDiscretizationConnection, DiscretizationConnectionElementGroup)
 
-    # My intuition tells me that this should not live inside a for loop.
-    # However, I need to grab a cl_context. I'll assume that each context from
-    # each partition is the same and I'll use the first one.
-    cl_context = vol_to_bdry_conns[0].from_discr.cl_context
-    with cl.CommandQueue(cl_context) as queue:
-        # Create a list of batches. Each batch contains interpolation
-        #   data from one partition to another.
-        for i_tgt_part, tgt_vol_conn in enumerate(vol_to_bdry_conns):
+    # Create a list of batches. Each batch contains interpolation
+    #   data from one partition to another.
+    for i_tgt_part, tgt_vol_conn in enumerate(vol_to_bdry_conns):
+
+        # Is this ok in a for loop?
+        cl_context = tgt_vol_conn.from_discr.cl_context
+        with cl.CommandQueue(cl_context) as queue:
+
             bdry_discr = tgt_vol_conn.to_discr
-            tgt_mesh = tgt_vol_conn.to_discr.mesh
+            tgt_mesh = bdry_discr.mesh
             ngroups = len(tgt_mesh.groups)
             part_batches = [[] for _ in range(ngroups)]
-            for tgt_group_num, adj in enumerate(tgt_mesh.interpart_adj_groups):
+            # Hack, I need to get InterPartitionAdj so I'll receive it directly
+            # as an argument.
+            for tgt_group_num, adj in enumerate(adj_parts[i_tgt_part]):
                 for idx, tgt_elem in enumerate(adj.elements):
                     tgt_face = adj.element_faces[idx]
 
-                    # We need to create a batch using the
-                    # neighboring face, element, and group
-                    # I'm not sure how I would do this.
-                    # My guess is that it would look
-                    # something like _make_cross_face_batches
                     part_batches[tgt_group_num].append(
                             _make_cross_partition_batch(
                                 queue,
@@ -614,4 +613,7 @@ def make_partition_connection(vol_to_bdry_conns):
 
     return disc_conns
 
+# }}}
+
+
 # vim: foldmethod=marker
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index c29f013a..9fdf1757 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -48,6 +48,8 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+# {{{ partition_interpolation
+
 def test_partition_interpolation(ctx_getter):
     cl_ctx = ctx_getter()
     order = 4
@@ -56,7 +58,11 @@ def test_partition_interpolation(ctx_getter):
     dim = 2
     num_parts = 7
     from meshmode.mesh.generation import generate_warped_rect_mesh
-    mesh = generate_warped_rect_mesh(dim, order=order, n=n)
+    mesh1 = generate_warped_rect_mesh(dim, order=order, n=n)
+    mesh2 = generate_warped_rect_mesh(dim, order=order, n=n)
+
+    from meshmode.mesh.processing import merge_disjoint_meshes
+    mesh = merge_disjoint_meshes([mesh1, mesh2])
 
     adjacency_list = np.zeros((mesh.nelements,), dtype=set)
     for elem in range(mesh.nelements):
@@ -73,6 +79,9 @@ def test_partition_interpolation(ctx_getter):
     part_meshes = [
         partition_mesh(mesh, part_per_element, i)[0] for i in range(num_parts)]
 
+    # Hack, I get InterPartitionAdj here instead of from vol_discrs.
+    adj_parts = [part_meshes[i].interpart_adj_groups for i in range(num_parts)]
+
     from meshmode.discretization import Discretization
     vol_discrs = [Discretization(cl_ctx, part_meshes[i], group_factory)
                     for i in range(num_parts)]
@@ -82,12 +91,14 @@ def test_partition_interpolation(ctx_getter):
                             FRESTR_INTERIOR_FACES) for i in range(num_parts)]
 
     from meshmode.discretization.connection import make_partition_connection
-    connections = make_partition_connection(bdry_connections)
+    connections = make_partition_connection(bdry_connections, adj_parts)
 
     from meshmode.discretization.connection import check_connection
     for conn in connections:
         check_connection(conn)
 
+# }}}
+
 
 # {{{ partition_mesh
 
-- 
GitLab