From df070ccd064be9679ed1fe6a135d159868670318 Mon Sep 17 00:00:00 2001
From: ellis <eshoag2@illinois.edu>
Date: Tue, 28 Mar 2017 10:06:39 -0500
Subject: [PATCH] Cleaned up call to _make_cross_partition_batch

---
 .../connection/opposite_face.py               | 151 +++++++-----------
 1 file changed, 57 insertions(+), 94 deletions(-)

diff --git a/meshmode/discretization/connection/opposite_face.py b/meshmode/discretization/connection/opposite_face.py
index c7ea09c0..6779bbd9 100644
--- a/meshmode/discretization/connection/opposite_face.py
+++ b/meshmode/discretization/connection/opposite_face.py
@@ -395,41 +395,31 @@ def make_opposite_face_connection(volume_to_bdry_conn):
 
 # {{{ partition_connection
 
-def _make_cross_partition_batch(queue, vol_to_bdry_conns, i_src_part, i_src_grp, i_src_elem, i_tgt_part, i_tgt_grp, i_tgt_elem):
+def _make_cross_partition_batch(queue, vol_to_bdry_conns, part_meshes,
+                                    i_tgt_part, i_tgt_grp, i_tgt_elem, i_tgt_face,
+                                    i_src_part, i_src_grp, i_src_elem, i_src_face):
     """
     Creates a batch that transfers data to a face from a face of another partition.
 
     :arg queue:
-    :arg vol_to_bdry_conns: A list of :class:`Direct` for each partition.
-    :arg adj: :class:`InterPartitionAdj` of partition `i_tgt_part`.
-    :arg i_tgt_part: The target partition number.
-    :arg i_tgt_grp:
-    :arg i_tgt_elem:
-    :arg i_tgt_face:
+    :arg vol_to_bdry_conns: A list of :class:`DirectDiscretizationConnection`
+                for each partition.
 
     :returns: ???
     """
 
-    src_bdry_discr = vol_to_bdry_conns[i_src_part].to_discr
+    src_mesh = part_meshes[i_src_part]
+    tgt_mesh = part_meshes[i_tgt_part]
+
+    adj = tgt_mesh.interpart_adj_groups[i_tgt_grp][i_src_part]
+
     tgt_bdry_discr = vol_to_bdry_conns[i_tgt_part].to_discr
+    src_bdry_discr = vol_to_bdry_conns[i_src_part].to_discr
 
-    tgt_bdry_nodes = (
-            # FIXME: This should view-then-transfer (but PyOpenCL doesn't do
-            # non-contiguous transfers for now).
-            tgt_bdry_discr.groups[i_tgt_grp].view(
-                tgt_bdry_discr.nodes().get(queue=queue))
-            [:, i_tgt_elem, :])
+    tgt_bdry_nodes = tgt_mesh.groups[i_tgt_grp].nodes[:, i_tgt_elem, :]
+    src_bdry_nodes = src_mesh.groups[i_src_grp].nodes[:, i_src_elem, :]
 
     ambient_dim, n_tgt_unit_nodes = tgt_bdry_nodes.shape
-    nelements = 1
-
-    # (ambient_dim, nelements, nfrom_unit_nodes)
-    src_bdry_nodes = (
-            # FIXME: This should view-then-transfer (but PyOpenCL doesn't do
-            # non-contiguous transfers for now).
-            src_bdry_discr.groups[i_src_grp].view(
-                src_bdry_discr.nodes().get(queue=queue))
-            )
 
     tol = 1e4 * np.finfo(tgt_bdry_nodes.dtype).eps
 
@@ -439,9 +429,8 @@ def _make_cross_partition_batch(queue, vol_to_bdry_conns, i_src_part, i_src_grp,
     dim = src_grp.dim
 
     initial_guess = np.mean(src_mesh_grp.vertex_unit_coordinates(), axis=0)
-
-    src_unit_nodes = np.empty((dim, nelements, n_tgt_unit_nodes))
-    src_unit_nodes[:] = initial_guess.reshape(-1, 1, 1)
+    src_unit_nodes = np.empty((dim, n_tgt_unit_nodes))
+    src_unit_nodes[:] = initial_guess.reshape(-1, 1)
 
     import modepy as mp
     src_vdm = mp.vandermonde(src_grp.basis(), src_grp.unit_nodes)
@@ -449,29 +438,29 @@ def _make_cross_partition_batch(queue, vol_to_bdry_conns, i_src_part, i_src_grp,
     src_nfuncs = len(src_grp.basis())
 
     def apply_map(unit_nodes):
-        # unit_nodes: (dim, nelements, nto_unit_nodes)
+        # unit_nodes: (dim, nto_unit_nodes)
         # basis_at_unit_nodes
-        basis_at_unit_nodes = np.empty((src_nfuncs, nelements, n_tgt_unit_nodes))
+        basis_at_unit_nodes = np.empty((src_nfuncs, n_tgt_unit_nodes))
         for i, f in enumerate(src_grp.basis()):
             basis_at_unit_nodes[i] = (
                     f(unit_nodes.reshape(dim, -1))
-                    .reshape(nelements, n_tgt_unit_nodes))
-        intp_coeffs = np.einsum("fj,jet->fet", src_inv_t_vdm, basis_at_unit_nodes)
+                    .reshape(n_tgt_unit_nodes))
+        intp_coeffs = np.einsum("fj,jt->ft", src_inv_t_vdm, basis_at_unit_nodes)
         # If we're interpolating 1, we had better get 1 back.
-        one_deviation = np.abs(np.sum(intp_coeffs, axis=0) - 1)
-        assert (one_deviation < tol).all(), np.max(one_deviation)
-        return np.einsum("fet,aef->aet", intp_coeffs, src_bdry_nodes)
+        #one_deviation = np.abs(np.sum(intp_coeffs, axis=0) - 1)
+        #assert (one_deviation < tol).all(), np.max(one_deviation)
+        return np.einsum("ft,af->at", intp_coeffs, src_bdry_nodes)
 
     def get_map_jacobian(unit_nodes):
-        # unit_nodes: (dim, nelements, nto_unit_nodes)
+        # unit_nodes: (dim, nto_unit_nodes)
         # basis_at_unit_nodes
         dbasis_at_unit_nodes = np.empty(
-                (dim, src_nfuncs, nelements, n_tgt_unit_nodes))
+                (dim, src_nfuncs, n_tgt_unit_nodes))
         for i, df in enumerate(src_grp.grad_basis()):
             df_result = df(unit_nodes.reshape(dim, -1))
             for rst_axis, df_r in enumerate(df_result):
                 dbasis_at_unit_nodes[rst_axis, i] = (
-                        df_r.reshape(nelements, n_tgt_unit_nodes))
+                        df_r.reshape(n_tgt_unit_nodes))
         dintp_coeffs = np.einsum(
                 "fj,rjet->rfet", src_inv_t_vdm, dbasis_at_unit_nodes)
         return np.einsum("rfet,aef->raet", dintp_coeffs, src_bdry_nodes)
@@ -486,6 +475,7 @@ def _make_cross_partition_batch(queue, vol_to_bdry_conns, i_src_part, i_src_grp,
         # equations and Cramer's rule. If you're looking for high-end
         # numerics, look no further than meshmode.
         if dim == 1:
+            # TODO: Needs testing.
             # A is df.T
             ata = np.einsum("iket,jket->ijet", df, df)
             atb = np.einsum("iket,ket->iet", df, resid)
@@ -505,10 +495,10 @@ def _make_cross_partition_batch(queue, vol_to_bdry_conns, i_src_part, i_src_grp,
             # This stinks, performance-wise, because it's not vectorized.
             # But we'll only hit it for boundaries of 4+D meshes, in which
             # case... good luck. :)
-            for e in range(nelements):
-                for t in range(n_tgt_unit_nodes):
-                    df_inv_resid[:, e, t], _, _, _ = \
-                            la.lstsq(df[:, :, e, t].T, resid[:, e, t])
+            # TODO: Needs testing.
+            for t in range(n_tgt_unit_nodes):
+                df_inv_resid[:, t], _, _, _ = \
+                        la.lstsq(df[:, :, t].T, resid[:, t])
         src_unit_nodes = src_unit_nodes - df_inv_resid
         max_resid = np.max(np.abs(resid))
         logger.debug("gauss-newton residual: %g" % max_resid)
@@ -524,33 +514,11 @@ def _make_cross_partition_batch(queue, vol_to_bdry_conns, i_src_part, i_src_grp,
     def to_dev(ary):
         return cl.array.to_device(queue, ary, array_queue=None)
 
-    done_elements = np.zeros(nelements, dtype=np.bool)
-
-    # TODO: Still need to figure out what's happening here.
-    while True:
-        todo_elements, = np.where(~done_elements)
-        if not len(todo_elements):
-            return
-        template_unit_nodes = src_unit_nodes[:, todo_elements[0], :]
-        unit_node_dist = np.max(np.max(np.abs(
-                src_unit_nodes[:, todo_elements, :]
-                -
-                template_unit_nodes.reshape(dim, 1, -1)),
-                axis=2), axis=0)
-        close_els = todo_elements[unit_node_dist < tol]
-        done_elements[close_els] = True
-        unit_node_dist = np.max(np.max(np.abs(
-                src_unit_nodes[:, todo_elements, :]
-                -
-                template_unit_nodes.reshape(dim, 1, -1)),
-                axis=2), axis=0)
-
-        from meshmode.discretization.connection import InterpolationBatch
-        yield InterpolationBatch(
+    return InterpolationBatch(
                 from_group_index=i_src_grp,
                 from_element_indices=to_dev(from_bdry_element_indices[close_els]),
                 to_element_indices=to_dev(to_bdry_element_indices[close_els]),
-                result_unit_nodes=template_unit_nodes,
+                result_unit_nodes=src_unit_nodes,
                 to_element_face=None)
 
 
@@ -583,37 +551,32 @@ def make_partition_connection(vol_to_bdry_conns, part_meshes):
             #tgt_mesh = bdry_discr.mesh
             tgt_mesh = part_meshes[i_tgt_part]
             ngroups = len(tgt_mesh.groups)
-            #part_batches = [[] for _ in range(ngroups)]
-            part_batches = []
-            for i_tgt_grp, adj in enumerate(tgt_mesh.interpart_adj_groups):
-                part_batches.append(_make_cross_partition_batches(
-                                        queue,
-                                        vol_to_bdry_conns,
-                                        adj,
-                                        tgt_mesh,
-                                        i_tgt_grp))
-                '''
-                for idx, i_tgt_elem in enumerate(adj.elements):
-                    i_tgt_face = adj.element_faces[idx]
-                    i_src_part = adj.part_indices[idx]
-                    i_src_elem = adj.neighbors[idx]
-                    i_src_face = adj.neighbor_faces[idx]
-                    #src_mesh = vol_to_bdry_conns[i_src_part].to_discr.mesh
+            part_batches = [[] for _ in range(ngroups)]
+            for i_tgt_grp, adj_parts in enumerate(tgt_mesh.interpart_adj_groups):
+                for i_src_part, adj in adj_parts.items():
+
                     src_mesh = part_meshes[i_src_part]
-                    i_src_grp = src_mesh.find_igrp(i_src_elem)
-                    i_src_elem -= src_mesh.groups[i_src_grp].element_nr_base
-
-                    part_batches[i_tgt_grp].extend(
-                            _make_cross_partition_batch(
-                                queue,
-                                vol_to_bdry_conns,
-                                i_src_part,
-                                i_src_grp,
-                                i_src_elem,
-                                i_tgt_part,
-                                i_tgt_grp,
-                                i_tgt_elem))
-                '''
+
+                    i_src_elems = adj.neighbors
+                    i_src_faces = adj.neighbor_faces
+                    i_src_grps = [src_mesh.find_igrp(e) for e in i_src_elems]
+                    for i in range(len(i_src_elems)):
+                        i_src_elems[i] -= src_mesh.groups[i_src_grps[i]].element_nr_base
+
+                    i_tgt_elems = adj.elements
+                    i_tgt_faces = adj.element_faces
+
+                    for idx, i_tgt_elem in enumerate(i_tgt_elems):
+                        i_tgt_face = i_tgt_faces[idx]
+                        i_src_elem = i_src_elems[idx]
+                        i_src_face = i_src_faces[idx]
+                        i_src_grp = i_src_grps[idx]
+
+                        part_batches[i_tgt_grp].append(
+                                _make_cross_partition_batch(queue,
+                                    vol_to_bdry_conns, part_meshes,
+                                    i_tgt_part, i_tgt_grp, i_tgt_elem, i_tgt_face,
+                                    i_src_part, i_src_grp, i_src_elem, i_src_face))
 
             # Make one Discr connection for each partition.
             disc_conns.append(DirectDiscretizationConnection(
-- 
GitLab