From b1d54cf831b61a5a4f7a0ce8d4d5c726a86f775e Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 21 Jun 2014 14:10:09 -0500
Subject: [PATCH] Make boundary restrictor batched

---
 meshmode/discretization/connection.py | 48 +++++++++++++++++++++------
 1 file changed, 38 insertions(+), 10 deletions(-)

diff --git a/meshmode/discretization/connection.py b/meshmode/discretization/connection.py
index 4c251aa..bbdfd21 100644
--- a/meshmode/discretization/connection.py
+++ b/meshmode/discretization/connection.py
@@ -162,7 +162,7 @@ class DiscretizationConnection(object):
     # }}}
 
 
-# {{{ constructor functions
+# {{{ same-mesh constructor
 
 def make_same_mesh_connection(queue, to_discr, from_discr):
     if from_discr.mesh is not to_discr.mesh:
@@ -188,6 +188,10 @@ def make_same_mesh_connection(queue, to_discr, from_discr):
     return DiscretizationConnection(
             from_discr, to_discr, groups)
 
+# }}}
+
+
+# {{{ boundary restriction constructor
 
 def make_boundary_extractor(queue, discr, group_factory):
     """
@@ -236,7 +240,7 @@ def make_boundary_extractor(queue, discr, group_factory):
     for igrp, grp in enumerate(discr.groups):
         mgrp = grp.mesh_el_group
         group_boundary_faces = [
-                (ibface_group, ibface_el, ibface_face)
+                (ibface_el, ibface_face)
                 for ibface_group, ibface_el, ibface_face in boundary_faces
                 if ibface_group == igrp]
 
@@ -265,13 +269,26 @@ def make_boundary_extractor(queue, discr, group_factory):
         grp_face_vertex_indices = mgrp.face_vertex_indices()
         grp_vertex_unit_coordinates = mgrp.vertex_unit_coordinates()
 
-        for ibdry_el, (ibface_group, ibface_el, ibface_face) in enumerate(
-                group_boundary_faces):
+        # batch by face_id
+
+        batch_base = 0
+
+        for face_id in xrange(len(grp_face_vertex_indices)):
+            batch_boundary_el_numbers_in_vol = np.array(
+                    [
+                        ibface_el
+                        for ibface_el, ibface_face in group_boundary_faces
+                        if ibface_face == face_id],
+                    dtype=np.intp)
+
+            new_el_numbers = np.arange(
+                    batch_base,
+                    batch_base + len(batch_boundary_el_numbers_in_vol))
+
+            # {{{ no per-element axes in these computations
 
             # Find boundary vertex indices
-            loc_face_vertices = list(grp_face_vertex_indices[ibface_face])
-            glob_face_vertices = mgrp.vertex_indices[ibface_el, loc_face_vertices]
-            vertex_indices[ibdry_el] = vol_to_bdry_vertices[glob_face_vertices]
+            loc_face_vertices = list(grp_face_vertex_indices[face_id])
 
             # Find unit nodes for boundary element
             face_vertex_unit_coordinates = \
@@ -290,10 +307,21 @@ def make_boundary_extractor(queue, discr, group_factory):
                     vol_basis,
                     face_unit_nodes, mgrp.unit_nodes)
 
-            nodes[:, ibdry_el, :] = np.einsum(
-                    "ij,dj->di",
+            # }}}
+
+            # Find vertex_indices
+            glob_face_vertices = mgrp.vertex_indices[
+                    batch_boundary_el_numbers_in_vol][:, loc_face_vertices]
+            vertex_indices[new_el_numbers] = \
+                    vol_to_bdry_vertices[glob_face_vertices]
+
+            # Find nodes
+            nodes[:, new_el_numbers, :] = np.einsum(
+                    "ij,dej->dei",
                     resampling_mat,
-                    mgrp.nodes[:, ibface_el, :])
+                    mgrp.nodes[:, batch_boundary_el_numbers_in_vol, :])
+
+            batch_base += len(batch_boundary_el_numbers_in_vol)
 
         bdry_mesh_group = SimplexElementGroup(
                 mgrp.order, vertex_indices, nodes, unit_nodes=bdry_unit_nodes)
-- 
GitLab