From d2109ce2c1432042fa6021ddc45123911fc8d28a Mon Sep 17 00:00:00 2001
From: ellis <eshoag2@illinois.edu>
Date: Sat, 25 Mar 2017 13:00:51 -0500
Subject: [PATCH] Slight progress on make_partition_connection

---
 .../connection/opposite_face.py               | 42 ++++++++++---------
 test/test_meshmode.py                         |  8 ++--
 2 files changed, 27 insertions(+), 23 deletions(-)

diff --git a/meshmode/discretization/connection/opposite_face.py b/meshmode/discretization/connection/opposite_face.py
index 925bfe0d..85696ed0 100644
--- a/meshmode/discretization/connection/opposite_face.py
+++ b/meshmode/discretization/connection/opposite_face.py
@@ -395,8 +395,7 @@ def make_opposite_face_connection(volume_to_bdry_conn):
 
 # {{{ partition_connection
 
-def _make_cross_partition_batch(queue, vol_to_bdry_conns, adj,
-                            i_tgt_part, i_tgt_grp, i_tgt_elem, i_tgt_face):
+def _make_cross_partition_batch(queue, vol_to_bdry_conns, i_src_part, i_src_grp, i_src_elem, i_tgt_part, i_tgt_grp, i_tgt_elem):
     """
     Creates a batch that transfers data to a face from a face of another partition.
 
@@ -411,9 +410,6 @@ def _make_cross_partition_batch(queue, vol_to_bdry_conns, adj,
     :returns: ???
     """
 
-    (i_src_part, i_src_grp, i_src_elem, i_src_face) =\
-                        adj.get_neighbor(i_tgt_elem, i_tgt_face)
-
     src_bdry_discr = vol_to_bdry_conns[i_src_part].to_discr
     tgt_bdry_discr = vol_to_bdry_conns[i_tgt_part].to_discr
 
@@ -424,7 +420,8 @@ def _make_cross_partition_batch(queue, vol_to_bdry_conns, adj,
                 tgt_bdry_discr.nodes().get(queue=queue))
             [:, i_tgt_elem])
 
-    ambient_dim, nelements, n_tgt_unit_nodes = tgt_bdry_nodes.shape
+    ambient_dim, n_tgt_unit_nodes = tgt_bdry_nodes.shape
+    nelements = 1
 
     # (ambient_dim, nelements, nfrom_unit_nodes)
     src_bdry_nodes = (
@@ -557,7 +554,7 @@ def _make_cross_partition_batch(queue, vol_to_bdry_conns, adj,
                 to_element_face=None)
 
 
-def make_partition_connection(vol_to_bdry_conns, adj_parts):
+def make_partition_connection(vol_to_bdry_conns, part_meshes):
     """
     Given a list of boundary restriction connections *volume_to_bdry_conn*,
     return a :class:`DirectDiscretizationConnection` that performs data
@@ -583,24 +580,31 @@ def make_partition_connection(vol_to_bdry_conns, adj_parts):
         with cl.CommandQueue(cl_context) as queue:
 
             bdry_discr = tgt_vol_conn.to_discr
-            tgt_mesh = bdry_discr.mesh
+            #tgt_mesh = bdry_discr.mesh
+            tgt_mesh = part_meshes[i_tgt_part]
             ngroups = len(tgt_mesh.groups)
             part_batches = [[] for _ in range(ngroups)]
-            # Hack, I need to get InterPartitionAdj so I'll receive it directly
-            # as an argument.
-            for tgt_group_num, adj in enumerate(adj_parts[i_tgt_part]):
-                for idx, tgt_elem in enumerate(adj.elements):
-                    tgt_face = adj.element_faces[idx]
-
-                    part_batches[tgt_group_num].append(
+            for i_tgt_grp, adj in enumerate(tgt_mesh.interpart_adj_groups):
+                for idx, i_tgt_elem in enumerate(adj.elements):
+                    i_tgt_face = adj.element_faces[idx]
+                    i_src_part = adj.part_indices[idx]
+                    i_src_elem = adj.neighbors[idx]
+                    i_src_face = adj.neighbor_faces[idx]
+                    #src_mesh = vol_to_bdry_conns[i_src_part].to_discr.mesh
+                    src_mesh = part_meshes[i_src_part]
+                    i_src_grp = src_mesh.find_igrp(i_src_elem)
+                    i_src_elem -= src_mesh.groups[i_src_grp].element_nr_base
+
+                    part_batches[i_tgt_grp].extend(
                             _make_cross_partition_batch(
                                 queue,
                                 vol_to_bdry_conns,
-                                adj,
+                                i_src_part,
+                                i_src_grp,
+                                i_src_elem,
                                 i_tgt_part,
-                                tgt_group_num,
-                                tgt_elem,
-                                tgt_face))
+                                i_tgt_grp,
+                                i_tgt_elem))
 
             # Make one Discr connection for each partition.
             disc_conns.append(DirectDiscretizationConnection(
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index 9fdf1757..c8b54d31 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -79,9 +79,6 @@ def test_partition_interpolation(ctx_getter):
     part_meshes = [
         partition_mesh(mesh, part_per_element, i)[0] for i in range(num_parts)]
 
-    # Hack, I get InterPartitionAdj here instead of from vol_discrs.
-    adj_parts = [part_meshes[i].interpart_adj_groups for i in range(num_parts)]
-
     from meshmode.discretization import Discretization
     vol_discrs = [Discretization(cl_ctx, part_meshes[i], group_factory)
                     for i in range(num_parts)]
@@ -90,11 +87,14 @@ def test_partition_interpolation(ctx_getter):
     bdry_connections = [make_face_restriction(vol_discrs[i], group_factory,
                             FRESTR_INTERIOR_FACES) for i in range(num_parts)]
 
+    # Hack, I probably shouldn't pass part_meshes directly. This is probably
+    # temporary.
     from meshmode.discretization.connection import make_partition_connection
-    connections = make_partition_connection(bdry_connections, adj_parts)
+    connections = make_partition_connection(bdry_connections, part_meshes)
 
     from meshmode.discretization.connection import check_connection
     for conn in connections:
+        print(conn)
         check_connection(conn)
 
 # }}}
-- 
GitLab