diff --git a/pytential/qbx/fmm.py b/pytential/qbx/fmm.py index 46d7ae880384f7dfbe44c21b7aea836ffdd4f6d3..badf630046b239303344e1b1a3664e706a12a0f4 100644 --- a/pytential/qbx/fmm.py +++ b/pytential/qbx/fmm.py @@ -198,10 +198,11 @@ QBXFMMGeometryData.non_qbx_box_target_lists`), @log_process(logger) def form_global_qbx_locals(self, src_weights): local_exps = self.qbx_local_expansion_zeros() + events = [] geo_data = self.geo_data if len(geo_data.global_qbx_centers()) == 0: - return local_exps + return (local_exps, SumpyTimingFuture(self.queue, events)) traversal = geo_data.traversal() @@ -227,25 +228,25 @@ QBXFMMGeometryData.non_qbx_box_target_lists`), **kwargs) + events.append(evt) assert local_exps is result result.add_event(evt) - return (result, SumpyTimingFuture(self.queue, [evt])) + return (result, SumpyTimingFuture(self.queue, events)) @log_process(logger) def translate_box_multipoles_to_qbx_local(self, multipole_exps): qbx_expansions = self.qbx_local_expansion_zeros() + events = [] geo_data = self.geo_data if geo_data.ncenters == 0: - return qbx_expansions + return (qbx_expansions, SumpyTimingFuture(self.queue, events)) traversal = geo_data.traversal() wait_for = multipole_exps.events - events = [] - for isrc_level, ssn in enumerate(traversal.from_sep_smaller_by_level): m2qbxl = self.code.m2qbxl( self.level_orders[isrc_level], @@ -277,7 +278,6 @@ QBXFMMGeometryData.non_qbx_box_target_lists`), **self.kernel_extra_kwargs) events.append(evt) - wait_for = [evt] assert qbx_expansions_res is qbx_expansions @@ -290,14 +290,15 @@ QBXFMMGeometryData.non_qbx_box_target_lists`), qbx_expansions = self.qbx_local_expansion_zeros() geo_data = self.geo_data + events = [] + if geo_data.ncenters == 0: - return qbx_expansions + return (qbx_expansions, SumpyTimingFuture(self.queue, events)) + trav = geo_data.traversal() wait_for = local_exps.events - events = [] - for isrc_level in range(geo_data.tree().nlevels): l2qbxl = self.code.l2qbxl( self.level_orders[isrc_level], @@ -326,7 +327,6 @@ QBXFMMGeometryData.non_qbx_box_target_lists`), **self.kernel_extra_kwargs) events.append(evt) - wait_for = [evt] assert qbx_expansions_res is qbx_expansions @@ -339,8 +339,10 @@ QBXFMMGeometryData.non_qbx_box_target_lists`), pot = self.full_output_zeros() geo_data = self.geo_data + events = [] + if len(geo_data.global_qbx_centers()) == 0: - return pot + return (pot, SumpyTimingFuture(self.queue, events)) ctt = geo_data.center_to_tree_targets() @@ -365,7 +367,7 @@ QBXFMMGeometryData.non_qbx_box_target_lists`), for pot_i, pot_res_i in zip(pot, pot_res): assert pot_i is pot_res_i - return (pot, SumpyTimingFuture(self.queue, [evt])) + return (pot, SumpyTimingFuture(self.queue, events)) # }}}