From edae8c6b1e9ed86a4e9d28adaf681e8badab3d48 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 8 Jan 2017 10:55:54 +0800
Subject: [PATCH] Make DiscretizationConnection an interface, introduce
 ChainedDiscretizationConnection

---
 .../discretization/connection/__init__.py     | 83 +++++++++++++++----
 meshmode/discretization/connection/face.py    | 13 +--
 .../connection/opposite_face.py               |  6 +-
 .../discretization/connection/refinement.py   |  5 +-
 .../discretization/connection/same_mesh.py    |  4 +-
 5 files changed, 81 insertions(+), 30 deletions(-)

diff --git a/meshmode/discretization/connection/__init__.py b/meshmode/discretization/connection/__init__.py
index ba158358..c0b622d0 100644
--- a/meshmode/discretization/connection/__init__.py
+++ b/meshmode/discretization/connection/__init__.py
@@ -56,6 +56,7 @@ __all__ = [
 
 __doc__ = """
 .. autoclass:: DiscretizationConnection
+.. autoclass:: DirectDiscretizationConnection
 
 .. autofunction:: make_same_mesh_connection
 
@@ -158,9 +159,9 @@ class DiscretizationConnectionElementGroup(object):
 # {{{ connection class
 
 class DiscretizationConnection(object):
-    """Data supporting an interpolation-like operation that takes in data on
-    one discretization and returns it on another. Implemented applications
-    include:
+    """Abstract interface for transporting a DOF vector from one
+    :class:`meshmode.discretization.Discretization` to another.
+    Possible applications include:
 
     *   upsampling/downsampling on the same mesh
     *   restricition to the boundary
@@ -171,24 +172,14 @@ class DiscretizationConnection(object):
 
     .. attribute:: to_discr
 
-    .. attribute:: groups
-
-        a list of :class:`DiscretizationConnectionElementGroup`
-        instances, with a one-to-one correspondence to the groups in
-        :attr:`to_discr`.
-
     .. attribute:: is_surjective
 
         A :class:`bool` indicating whether every output degree
         of freedom is set by the connection.
 
     .. automethod:: __call__
-
-    .. automethod:: full_resample_matrix
-
     """
-
-    def __init__(self, from_discr, to_discr, groups, is_surjective):
+    def __init__(self, from_discr, to_discr, is_surjective):
         if from_discr.cl_context != to_discr.cl_context:
             raise ValueError("from_discr and to_discr must live in the "
                     "same OpenCL context")
@@ -203,14 +194,72 @@ class DiscretizationConnection(object):
             raise ValueError("from_discr and to_discr must agree on the "
                     "element_id_dtype")
 
-        self.cl_context = from_discr.cl_context
-
         self.from_discr = from_discr
         self.to_discr = to_discr
-        self.groups = groups
 
         self.is_surjective = is_surjective
 
+    def __call__(self, queue, vec):
+        raise NotImplementedError()
+
+
+class ChainedDiscretizationConnection(DiscretizationConnection):
+    """Aggregates multiple :class:`DiscretizationConnection` instances
+    into a single one.
+
+    .. attribute:: connections
+    """
+
+    def __init__(self, connections):
+        if not connections:
+            raise ValueError("connections may not be empty")
+
+        super(DirectDiscretizationConnection, self).__init__(
+                connections[0].from_discr,
+                connections[-1].to_discr,
+                is_surjective=all(
+                    cnx.is_surjective for cnx in connections))
+
+        self.connections = connections
+
+    def __call__(self, queue, vec):
+        for cnx in self.connections:
+            vec = cnx(queue, vec)
+
+        return vec
+
+
+class DirectDiscretizationConnection(DiscretizationConnection):
+    """A concrete :class:`DiscretizationConnection` supported by interpolation
+    data.
+
+    .. attribute:: from_discr
+
+    .. attribute:: to_discr
+
+    .. attribute:: groups
+
+        a list of :class:`DiscretizationConnectionElementGroup`
+        instances, with a one-to-one correspondence to the groups in
+        :attr:`to_discr`.
+
+    .. attribute:: is_surjective
+
+        A :class:`bool` indicating whether every output degree
+        of freedom is set by the connection.
+
+    .. automethod:: __call__
+
+    .. automethod:: full_resample_matrix
+
+    """
+
+    def __init__(self, from_discr, to_discr, groups, is_surjective):
+        super(DirectDiscretizationConnection, self).__init__(
+                from_discr, to_discr, is_surjective)
+
+        self.groups = groups
+
     @memoize_method
     def _resample_matrix(self, to_group_index, ibatch_index):
         import modepy as mp
diff --git a/meshmode/discretization/connection/face.py b/meshmode/discretization/connection/face.py
index 124b4252..88d91d9e 100644
--- a/meshmode/discretization/connection/face.py
+++ b/meshmode/discretization/connection/face.py
@@ -62,7 +62,7 @@ def _build_boundary_connection(queue, vol_discr, bdry_discr, connection_data,
         per_face_groups):
     from meshmode.discretization.connection import (
             InterpolationBatch, DiscretizationConnectionElementGroup,
-            DiscretizationConnection)
+            DirectDiscretizationConnection)
 
     ibdry_grp = 0
     batches = []
@@ -105,7 +105,7 @@ def _build_boundary_connection(queue, vol_discr, bdry_discr, connection_data,
 
     assert ibdry_grp == len(bdry_discr.groups)
 
-    return DiscretizationConnection(
+    return DirectDiscretizationConnection(
             vol_discr, bdry_discr, connection_groups,
             is_surjective=True)
 
@@ -180,10 +180,11 @@ def make_face_restriction(discr, group_factory, boundary_tag,
         each other one-to-one, and an interpolation batch is created
         per face.
 
-    :return: a :class:`meshmode.discretization.connection.DiscretizationConnection`
+    :return: a
+        :class:`meshmode.discretization.connection.DirectDiscretizationConnection`
         representing the new connection. The new boundary discretization can be
         obtained from the
-        :attr:`meshmode.discretization.connection.DiscretizationConnection.to_discr`
+        :attr:`meshmode.discretization.connection.DirectDiscretizationConnection.to_discr`
         attribute of the return value, and the corresponding new boundary mesh
         from that.
 
@@ -415,7 +416,7 @@ def make_face_to_all_faces_embedding(faces_connection, all_faces_discr):
                 "same number of groups")
 
     from meshmode.discretization.connection import (
-            DiscretizationConnection,
+            DirectDiscretizationConnection,
             DiscretizationConnectionElementGroup,
             InterpolationBatch)
 
@@ -477,7 +478,7 @@ def make_face_to_all_faces_embedding(faces_connection, all_faces_discr):
 
                     i_faces_grp += 1
 
-    return DiscretizationConnection(
+    return DirectDiscretizationConnection(
             faces_connection.to_discr,
             all_faces_discr,
             groups,
diff --git a/meshmode/discretization/connection/opposite_face.py b/meshmode/discretization/connection/opposite_face.py
index 70728522..6ce70b2a 100644
--- a/meshmode/discretization/connection/opposite_face.py
+++ b/meshmode/discretization/connection/opposite_face.py
@@ -334,7 +334,7 @@ def _make_el_lookup_table(queue, connection, igrp):
 
 def make_opposite_face_connection(volume_to_bdry_conn):
     """Given a boundary restriction connection *volume_to_bdry_conn*,
-    return a :class:`DiscretizationConnection` that performs data
+    return a :class:`DirectDiscretizationConnection` that performs data
     exchange across opposite faces.
     """
 
@@ -381,8 +381,8 @@ def make_opposite_face_connection(volume_to_bdry_conn):
                             vbc_tgt_grp_face_batch, src_grp_el_lookup))
 
     from meshmode.discretization.connection import (
-            DiscretizationConnection, DiscretizationConnectionElementGroup)
-    return DiscretizationConnection(
+            DirectDiscretizationConnection, DiscretizationConnectionElementGroup)
+    return DirectDiscretizationConnection(
             from_discr=bdry_discr,
             to_discr=bdry_discr,
             groups=[
diff --git a/meshmode/discretization/connection/refinement.py b/meshmode/discretization/connection/refinement.py
index ff009fce..f487b2bf 100644
--- a/meshmode/discretization/connection/refinement.py
+++ b/meshmode/discretization/connection/refinement.py
@@ -127,7 +127,8 @@ def make_refinement_connection(refiner, coarse_discr, group_factory):
         discretizing the fine mesh.
     """
     from meshmode.discretization.connection import (
-        DiscretizationConnectionElementGroup, DiscretizationConnection)
+        DiscretizationConnectionElementGroup,
+        DirectDiscretizationConnection)
 
     coarse_mesh = refiner.get_previous_mesh()
     fine_mesh = refiner.last_mesh
@@ -157,7 +158,7 @@ def make_refinement_connection(refiner, coarse_discr, group_factory):
 
     logger.info("building refinement connection: done")
 
-    return DiscretizationConnection(
+    return DirectDiscretizationConnection(
         from_discr=coarse_discr,
         to_discr=fine_discr,
         groups=groups,
diff --git a/meshmode/discretization/connection/same_mesh.py b/meshmode/discretization/connection/same_mesh.py
index 3b73e34f..a470e929 100644
--- a/meshmode/discretization/connection/same_mesh.py
+++ b/meshmode/discretization/connection/same_mesh.py
@@ -32,7 +32,7 @@ import pyopencl.array  # noqa
 def make_same_mesh_connection(to_discr, from_discr):
     from meshmode.discretization.connection import (
             InterpolationBatch, DiscretizationConnectionElementGroup,
-            DiscretizationConnection)
+            DirectDiscretizationConnection)
 
     if from_discr.mesh is not to_discr.mesh:
         raise ValueError("from_discr and to_discr must be based on "
@@ -56,7 +56,7 @@ def make_same_mesh_connection(to_discr, from_discr):
             groups.append(
                     DiscretizationConnectionElementGroup([ibatch]))
 
-    return DiscretizationConnection(
+    return DirectDiscretizationConnection(
             from_discr, to_discr, groups,
             is_surjective=True)
 
-- 
GitLab