From cb912285e2f4998fdc63b52e571e075e8e076e49 Mon Sep 17 00:00:00 2001 From: Michael Campbell Date: Sat, 17 Apr 2021 11:25:25 -0500 Subject: [PATCH] Undumben slightly to handle arrays of arbitrary shape. --- grudge/op.py | 48 +++++++++++++++--------------------------------- 1 file changed, 15 insertions(+), 33 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 42491322..dcfc6efe 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -491,24 +491,17 @@ def cross_rank_trace_pairs(dcoll, vec, tag=None): data as the `internal`, and `external` components, respectively. Each of the TracePair components are structured like *vec*. - The input field data *vec* may be an array up to shape like (n, m). - *vec* may be a scalar(single) DOFArray, an n-vector of DOFArrays, - or an $n \times m$ object array of DOFArrays. - - Each of n*m components are independently communicated by calling - this routine. Upon entry, *vec* is serialized (if needed), each - component is communicated, then (if needed) the components are - de-serialized back to the structure of *vec*, before returned as - TracePairs for each partition boundary. + The input field data *vec* may be a single DOFArray, or an object + array of DOFArrays. Each of *vec* components are independently + communicated by calling this routine. Upon entry, *vec* is + serialized (if needed), each component is communicated, then + (if needed) the components are de-serialized back to the original + structure of *vec*, before returned as TracePairs for each partition + boundary. """ if isinstance(vec, np.ndarray): - if vec.ndim == 2: - vec_n, vec_m = vec.shape - comm_vec = vec.reshape((vec_n*vec_m,)) - else: - comm_vec = vec - vec_m = 0 - vec_n, = comm_vec.shape + original_shape = vec.shape + comm_vec = vec.flatten() comm_n, = comm_vec.shape result = {} @@ -521,25 +514,14 @@ def cross_rank_trace_pairs(dcoll, vec, tag=None): pb_tpairs = [] for remote_rank in connected_ranks(dcoll): - int_result = {} - ext_result = {} - for ivec in range(vec_n): - if vec.ndim == 2: - int_result[ivec] = ( - make_obj_array([result[remote_rank, ivec*vec_m+i].int - for i in range(vec_m)]) - ) - ext_result[ivec] = ( - make_obj_array([result[remote_rank, ivec*vec_m+i].ext - for i in range(vec_m)]) - ) - else: - int_result[ivec] = result[remote_rank, ivec].int - ext_result[ivec] = result[remote_rank, ivec].ext + int_result = make_obj_array([result[remote_rank, i].int + for i in range(comm_n)]) + ext_result = make_obj_array([result[remote_rank, i].ext + for i in range(comm_n)]) pb_tpairs.append(TracePair( dd=sym.as_dofdesc(sym.DTAG_BOUNDARY(BTAG_PARTITION(remote_rank))), - interior=make_obj_array([int_result[i] for i in range(vec_n)]), - exterior=make_obj_array([ext_result[i] for i in range(vec_n)]))) + interior=int_result.reshape(original_shape), + exterior=ext_result.reshape(original_shape))) return pb_tpairs else: return _cross_rank_trace_pairs_scalar_field(dcoll, vec, tag=tag) -- GitLab