diff --git a/test/test_cost_model.py b/test/test_cost_model.py index 482262ec28582f40adfdde816cb9ec408417e388..7b24f56713dcd338cebbdcef7bb9d59c4e5cbcbf 100644 --- a/test/test_cost_model.py +++ b/test/test_cost_model.py @@ -342,16 +342,18 @@ class ConstantOneQBXExpansionWrangler(ConstantOneExpansionWrangler): center_to_tree_targets = self.geo_data.center_to_tree_targets() qbx_center_to_target_box = self.geo_data.qbx_center_to_target_box() + target_box_to_src_sum = {} + target_box_to_nsrcs = {} + for ictr in global_qbx_centers: tgt_ibox = qbx_center_to_target_box[ictr] - ictr_tgt_start, ictr_tgt_end = center_to_tree_targets.starts[ictr:ictr+2] + isrc_box_start, isrc_box_end = ( + self.trav.neighbor_source_boxes_starts[tgt_ibox:tgt_ibox+2]) - for ictr_tgt in range(ictr_tgt_start, ictr_tgt_end): - ctr_itgt = center_to_tree_targets.lists[ictr_tgt] - - isrc_box_start, isrc_box_end = ( - self.trav.neighbor_source_boxes_starts[tgt_ibox:tgt_ibox+2]) + if tgt_ibox not in target_box_to_src_sum: + nsrcs = 0 + src_sum = 0 for isrc_box in range(isrc_box_start, isrc_box_end): src_ibox = self.trav.neighbor_source_boxes_lists[isrc_box] @@ -360,8 +362,22 @@ class ConstantOneQBXExpansionWrangler(ConstantOneExpansionWrangler): isrc_end = (isrc_start + self.tree.box_source_counts_nonchild[src_ibox]) - pot[0][ctr_itgt] += sum(src_weights[isrc_start:isrc_end]) - ops += isrc_end - isrc_start + src_sum += sum(src_weights[isrc_start:isrc_end]) + nsrcs += isrc_end - isrc_start + + target_box_to_src_sum[tgt_ibox] = src_sum + target_box_to_nsrcs[tgt_ibox] = nsrcs + + src_sum = target_box_to_src_sum[tgt_ibox] + nsrcs = target_box_to_nsrcs[tgt_ibox] + + ictr_tgt_start, ictr_tgt_end = center_to_tree_targets.starts[ictr:ictr+2] + + for ictr_tgt in range(ictr_tgt_start, ictr_tgt_end): + ctr_itgt = center_to_tree_targets.lists[ictr_tgt] + pot[0][ctr_itgt] = src_sum + + ops += (ictr_tgt_end - ictr_tgt_start) * nsrcs return pot, self.timing_future(ops)