From a169c29f88e725cec5ce2ee88c85af4c54387e86 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 18 Jun 2020 16:19:33 -0500
Subject: [PATCH] bind(): Introduce local_only flag

---
 grudge/execution.py | 64 +++++++++++++++++++++++++++------------------
 1 file changed, 39 insertions(+), 25 deletions(-)

diff --git a/grudge/execution.py b/grudge/execution.py
index c3993435..5d507bea 100644
--- a/grudge/execution.py
+++ b/grudge/execution.py
@@ -684,8 +684,14 @@ class BoundOperator(object):
 
 # {{{ process_sym_operator function
 
-def process_sym_operator(discrwb, sym_operator, post_bind_mapper=None,
-        dumper=lambda name, sym_operator: None):
+def process_sym_operator(discrwb, sym_operator, post_bind_mapper=None, dumper=None,
+        local_only=None):
+    if local_only is None:
+        local_only = False
+
+    if dumper is None:
+        def dumper(name, sym_operator):
+            return
 
     orig_sym_operator = sym_operator
     import grudge.symbolic.mappers as mappers
@@ -698,26 +704,26 @@ def process_sym_operator(discrwb, sym_operator, post_bind_mapper=None,
     sym_operator = \
             mappers.OppositeInteriorFaceSwapUniqueIDAssigner()(sym_operator)
 
-    # {{{ broadcast root rank's symn_operator
+    if not local_only:
+        # {{{ broadcast root rank's symn_operator
 
-    # also make sure all ranks had same orig_sym_operator
+        # also make sure all ranks had same orig_sym_operator
 
-    if discrwb.mpi_communicator is not None:
-        (mgmt_rank_orig_sym_operator, mgmt_rank_sym_operator) = \
-                discrwb.mpi_communicator.bcast(
-                    (orig_sym_operator, sym_operator),
-                    discrwb.get_management_rank_index())
+        if discrwb.mpi_communicator is not None:
+            (mgmt_rank_orig_sym_operator, mgmt_rank_sym_operator) = \
+                    discrwb.mpi_communicator.bcast(
+                        (orig_sym_operator, sym_operator),
+                        discrwb.get_management_rank_index())
 
-        from pytools.obj_array import is_equal as is_oa_equal
-        if not is_oa_equal(mgmt_rank_orig_sym_operator, orig_sym_operator):
-            raise ValueError("rank %d received a different symbolic "
-                    "operator to bind from rank %d"
-                    % (discrwb.mpi_communicator.Get_rank(),
-                        discrwb.get_management_rank_index()))
+            if not np.array_equal(mgmt_rank_orig_sym_operator, orig_sym_operator):
+                raise ValueError("rank %d received a different symbolic "
+                        "operator to bind from rank %d"
+                        % (discrwb.mpi_communicator.Get_rank(),
+                            discrwb.get_management_rank_index()))
 
-        sym_operator = mgmt_rank_sym_operator
+            sym_operator = mgmt_rank_sym_operator
 
-    # }}}
+        # }}}
 
     if post_bind_mapper is not None:
         dumper("before-postbind", sym_operator)
@@ -753,12 +759,13 @@ def process_sym_operator(discrwb, sym_operator, post_bind_mapper=None,
 
     dumper("before-distributed", sym_operator)
 
-    volume_mesh = discrwb.discr_from_dd("vol").mesh
-    from meshmode.distributed import get_connected_partitions
-    connected_parts = get_connected_partitions(volume_mesh)
+    if not local_only:
+        volume_mesh = discrwb.discr_from_dd("vol").mesh
+        from meshmode.distributed import get_connected_partitions
+        connected_parts = get_connected_partitions(volume_mesh)
 
-    if connected_parts:
-        sym_operator = mappers.DistributedMapper(connected_parts)(sym_operator)
+        if connected_parts:
+            sym_operator = mappers.DistributedMapper(connected_parts)(sym_operator)
 
     dumper("before-imass", sym_operator)
     sym_operator = mappers.InverseMassContractor()(sym_operator)
@@ -777,10 +784,16 @@ def process_sym_operator(discrwb, sym_operator, post_bind_mapper=None,
 # }}}
 
 
-def bind(discr, sym_operator, post_bind_mapper=lambda x: x,
+def bind(discr, sym_operator, *, post_bind_mapper=lambda x: x,
         function_registry=base_function_registry,
         exec_mapper_factory=ExecutionMapper,
-        debug_flags=frozenset()):
+        debug_flags=frozenset(), local_only=None):
+    """
+    :param local_only: If *True*, *sym_operator* should oly be evaluated on the
+        local part of the mesh. No inter-rank communication will take place.
+        (However rank boundaries, tagged :class:`~meshmode.mesh.BTAG_PARTITION`,
+        will not automatically be considered part of the domain boundary.)
+    """
     # from grudge.symbolic.mappers import QuadratureUpsamplerRemover
     # sym_operator = QuadratureUpsamplerRemover(self.quad_min_degrees)(
     #         sym_operator)
@@ -800,7 +813,8 @@ def bind(discr, sym_operator, post_bind_mapper=lambda x: x,
             discr,
             sym_operator,
             post_bind_mapper=post_bind_mapper,
-            dumper=dump_sym_operator)
+            dumper=dump_sym_operator,
+            local_only=local_only)
 
     from grudge.symbolic.compiler import OperatorCompiler
     discr_code, eval_code = OperatorCompiler(discr, function_registry)(sym_operator)
-- 
GitLab