From a169c29f88e725cec5ce2ee88c85af4c54387e86 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 18 Jun 2020 16:19:33 -0500 Subject: [PATCH] bind(): Introduce local_only flag --- grudge/execution.py | 64 +++++++++++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 25 deletions(-) diff --git a/grudge/execution.py b/grudge/execution.py index c3993435..5d507bea 100644 --- a/grudge/execution.py +++ b/grudge/execution.py @@ -684,8 +684,14 @@ class BoundOperator(object): # {{{ process_sym_operator function -def process_sym_operator(discrwb, sym_operator, post_bind_mapper=None, - dumper=lambda name, sym_operator: None): +def process_sym_operator(discrwb, sym_operator, post_bind_mapper=None, dumper=None, + local_only=None): + if local_only is None: + local_only = False + + if dumper is None: + def dumper(name, sym_operator): + return orig_sym_operator = sym_operator import grudge.symbolic.mappers as mappers @@ -698,26 +704,26 @@ def process_sym_operator(discrwb, sym_operator, post_bind_mapper=None, sym_operator = \ mappers.OppositeInteriorFaceSwapUniqueIDAssigner()(sym_operator) - # {{{ broadcast root rank's symn_operator + if not local_only: + # {{{ broadcast root rank's symn_operator - # also make sure all ranks had same orig_sym_operator + # also make sure all ranks had same orig_sym_operator - if discrwb.mpi_communicator is not None: - (mgmt_rank_orig_sym_operator, mgmt_rank_sym_operator) = \ - discrwb.mpi_communicator.bcast( - (orig_sym_operator, sym_operator), - discrwb.get_management_rank_index()) + if discrwb.mpi_communicator is not None: + (mgmt_rank_orig_sym_operator, mgmt_rank_sym_operator) = \ + discrwb.mpi_communicator.bcast( + (orig_sym_operator, sym_operator), + discrwb.get_management_rank_index()) - from pytools.obj_array import is_equal as is_oa_equal - if not is_oa_equal(mgmt_rank_orig_sym_operator, orig_sym_operator): - raise ValueError("rank %d received a different symbolic " - "operator to bind from rank %d" - % (discrwb.mpi_communicator.Get_rank(), - discrwb.get_management_rank_index())) + if not np.array_equal(mgmt_rank_orig_sym_operator, orig_sym_operator): + raise ValueError("rank %d received a different symbolic " + "operator to bind from rank %d" + % (discrwb.mpi_communicator.Get_rank(), + discrwb.get_management_rank_index())) - sym_operator = mgmt_rank_sym_operator + sym_operator = mgmt_rank_sym_operator - # }}} + # }}} if post_bind_mapper is not None: dumper("before-postbind", sym_operator) @@ -753,12 +759,13 @@ def process_sym_operator(discrwb, sym_operator, post_bind_mapper=None, dumper("before-distributed", sym_operator) - volume_mesh = discrwb.discr_from_dd("vol").mesh - from meshmode.distributed import get_connected_partitions - connected_parts = get_connected_partitions(volume_mesh) + if not local_only: + volume_mesh = discrwb.discr_from_dd("vol").mesh + from meshmode.distributed import get_connected_partitions + connected_parts = get_connected_partitions(volume_mesh) - if connected_parts: - sym_operator = mappers.DistributedMapper(connected_parts)(sym_operator) + if connected_parts: + sym_operator = mappers.DistributedMapper(connected_parts)(sym_operator) dumper("before-imass", sym_operator) sym_operator = mappers.InverseMassContractor()(sym_operator) @@ -777,10 +784,16 @@ def process_sym_operator(discrwb, sym_operator, post_bind_mapper=None, # }}} -def bind(discr, sym_operator, post_bind_mapper=lambda x: x, +def bind(discr, sym_operator, *, post_bind_mapper=lambda x: x, function_registry=base_function_registry, exec_mapper_factory=ExecutionMapper, - debug_flags=frozenset()): + debug_flags=frozenset(), local_only=None): + """ + :param local_only: If *True*, *sym_operator* should oly be evaluated on the + local part of the mesh. No inter-rank communication will take place. + (However rank boundaries, tagged :class:`~meshmode.mesh.BTAG_PARTITION`, + will not automatically be considered part of the domain boundary.) + """ # from grudge.symbolic.mappers import QuadratureUpsamplerRemover # sym_operator = QuadratureUpsamplerRemover(self.quad_min_degrees)( # sym_operator) @@ -800,7 +813,8 @@ def bind(discr, sym_operator, post_bind_mapper=lambda x: x, discr, sym_operator, post_bind_mapper=post_bind_mapper, - dumper=dump_sym_operator) + dumper=dump_sym_operator, + local_only=local_only) from grudge.symbolic.compiler import OperatorCompiler discr_code, eval_code = OperatorCompiler(discr, function_registry)(sym_operator) -- GitLab