diff --git a/grudge/eager.py b/grudge/eager.py index f10afcec37c4480098f1c3904a408d2f437e489b..7d9eb1d884844cfe34f8d3425df8b3e6bc44824e 100644 --- a/grudge/eager.py +++ b/grudge/eager.py @@ -374,7 +374,15 @@ def interior_trace_pair(discrwb, vec): *discrwb*. """ i = discrwb.project("vol", "int_faces", vec) - e = obj_array_vectorize(lambda el: discrwb.opposite_face_connection()(el), i) + + def get_opposite_face(el): + if isinstance(el, Number): + return el + else: + return discrwb.opposite_face_connection()(el) + + e = obj_array_vectorize(get_opposite_face, i) + return TracePair("int_faces", interior=i, exterior=e) @@ -424,9 +432,13 @@ class _RankBoundaryCommunication: def _cross_rank_trace_pairs_scalar_field(discrwb, vec, tag=None): - rbcomms = [_RankBoundaryCommunication(discrwb, remote_rank, vec, tag=tag) - for remote_rank in discrwb.connected_ranks()] - return [rbcomm.finish() for rbcomm in rbcomms] + if isinstance(vec, Number): + return [TracePair(BTAG_PARTITION(remote_rank), interior=vec, exterior=vec) + for remote_rank in discrwb.connected_ranks()] + else: + rbcomms = [_RankBoundaryCommunication(discrwb, remote_rank, vec, tag=tag) + for remote_rank in discrwb.connected_ranks()] + return [rbcomm.finish() for rbcomm in rbcomms] def cross_rank_trace_pairs(discrwb, vec, tag=None):