From f5826c0da96501d235e08d57324804cb0b6259e1 Mon Sep 17 00:00:00 2001 From: Ellis Date: Mon, 13 Nov 2017 13:34:06 -0600 Subject: [PATCH] Working --- grudge/execution.py | 8 ++++++-- grudge/symbolic/mappers/__init__.py | 11 ++++++++--- grudge/symbolic/operators.py | 13 ++++++++----- grudge/symbolic/primitives.py | 2 +- test/test_mpi_communication.py | 2 +- 5 files changed, 24 insertions(+), 12 deletions(-) diff --git a/grudge/execution.py b/grudge/execution.py index c84db90c..95f9af0d 100644 --- a/grudge/execution.py +++ b/grudge/execution.py @@ -293,11 +293,15 @@ class ExecutionMapper(mappers.Evaluator, self.discr.volume_discr, group_factory) - raise NotImplementedError("map_opposite_rank_face_swap") + # raise NotImplementedError("map_opposite_rank_face_swap") + + if op.remote_rank not in bdry_comm.connected_parts: + # Perhaps this should be detected earlier + return 0 # FIXME: One rank face swap should swap data between the local rank # and exactly one remote rank - return bdry_comm.remote_to_local_bdry_conns[0]( + return bdry_comm.remote_to_local_bdry_conns[op.remote_rank]( self.queue, self.rec(field_expr)).with_queue(self.queue) def map_opposite_interior_face_swap(self, op, field_expr): diff --git a/grudge/symbolic/mappers/__init__.py b/grudge/symbolic/mappers/__init__.py index c528c01d..9daab6da 100644 --- a/grudge/symbolic/mappers/__init__.py +++ b/grudge/symbolic/mappers/__init__.py @@ -339,9 +339,14 @@ class DistributedMapper(CSECachingMapperMixin, IdentityMapper): def map_operator_binding(self, expr): if isinstance(expr.op, op.OppositeInteriorFaceSwap): - # FIXME: Add the sum of the rank face swaps over each rank - return (op.OppositeInteriorFaceSwap()(self.rec(expr.field)) - + op.OppositeRankFaceSwap()(self.rec(expr.field))) + result = op.OppositeInteriorFaceSwap()(self.rec(expr.field)) + # FIXME: Maybe narrow this down + from mpi4py import MPI + num_ranks = MPI.COMM_WORLD.Get_size() + connected_ranks = range(num_ranks) + for remote_rank in connected_ranks: + result += op.OppositeRankFaceSwap(remote_rank)(self.rec(expr.field)) + return result else: return IdentityMapper.map_operator_binding(self, expr) diff --git a/grudge/symbolic/operators.py b/grudge/symbolic/operators.py index 188b37c8..f91a2206 100644 --- a/grudge/symbolic/operators.py +++ b/grudge/symbolic/operators.py @@ -380,22 +380,25 @@ class RefInverseMassOperator(RefMassOperatorBase): # {{{ boundary-related operators class OppositeRankFaceSwap(Operator): - def __init__(self, dd_in=None, dd_out=None): + def __init__(self, remote_rank, dd_in=None, dd_out=None): sym = _sym() if dd_in is None: # FIXME: Use BTAG_PARTITION instead dd_in = sym.DOFDesc(sym.FRESTR_INTERIOR_FACES) + # dd_in = sym.DOFDesc(sym.BTAG_PARTITION) if dd_out is None: dd_out = dd_in - if dd_in.domain_tag is not sym.FRESTR_INTERIOR_FACES: - raise ValueError("dd_in must be an interior faces domain") - if dd_out != dd_in: - raise ValueError("dd_out and dd_in must be identical") + # if dd_in.domain_tag is not sym.BTAG_PARTITION: + # raise ValueError("dd_in must be an interior faces domain") + # if dd_out != dd_in: + # raise ValueError("dd_out and dd_in must be identical") super(OppositeRankFaceSwap, self).__init__(dd_in, dd_out) + self.remote_rank = remote_rank + mapper_method = intern("map_opposite_rank_face_swap") diff --git a/grudge/symbolic/primitives.py b/grudge/symbolic/primitives.py index 173a1a2d..5827805f 100644 --- a/grudge/symbolic/primitives.py +++ b/grudge/symbolic/primitives.py @@ -181,7 +181,7 @@ class DOFDesc(object): pass elif domain_tag is None: pass - elif domain_tag in [BTAG_ALL, BTAG_REALLY_ALL, BTAG_NONE]: + elif domain_tag in [BTAG_ALL, BTAG_REALLY_ALL, BTAG_NONE, BTAG_PARTITION]: pass elif isinstance(domain_tag, DTAG_BOUNDARY): pass diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py index 55c364b6..29aab0d9 100644 --- a/test/test_mpi_communication.py +++ b/test/test_mpi_communication.py @@ -112,7 +112,7 @@ def mpi_communication_entrypoint(): dt_stepper = set_up_rk4("w", dt, fields, rhs) - final_t = 10 + final_t = 1 nsteps = int(final_t/dt) print("dt=%g nsteps=%d" % (dt, nsteps)) -- GitLab