From d18375016671751d039bf3a8ecd37504193e4fd7 Mon Sep 17 00:00:00 2001 From: Matt Wala <wala1@illinois.edu> Date: Mon, 9 Jul 2018 23:02:48 -0500 Subject: [PATCH] Update for boxtree changes --- sumpy/fmm.py | 107 +++++++++++++++++++---------------------------- test/test_fmm.py | 2 +- 2 files changed, 43 insertions(+), 66 deletions(-) diff --git a/sumpy/fmm.py b/sumpy/fmm.py index cd3ed0af..17d52eed 100644 --- a/sumpy/fmm.py +++ b/sumpy/fmm.py @@ -150,13 +150,29 @@ class SumpyExpansionWranglerCodeContainer(object): _SECONDS_PER_NANOSECOND = 1e-9 -class TimingFuture(object): +class UnableToCollectTimingData(UserWarning): + pass + + +class SumpyTimingFuture(object): - def __init__(self, events): + def __init__(self, queue, events): + self.queue = queue self.events = events @memoize_method - def get(self): + def result(self): + from boxtree.fmm import TimingResult + + if not self.queue.properties & cl.command_queue_properties.PROFILING_ENABLE: + from warnings import warn + warn( + "Profiling was not enabled in the command queue. " + "Timing data will not be collected.", + category=UnableToCollectTimingData, + stacklevel=3) + return TimingResult(wall_elapsed=None, process_elapsed=None) + pyopencl.wait_for_events(self.events) result = 0 @@ -164,21 +180,20 @@ class TimingFuture(object): result += ( (event.profile.end - event.profile.start) * _SECONDS_PER_NANOSECOND) - return result - def __call__(self): - from boxtree.fmm import TimingResult - return TimingResult(wall_elapsed=self.get(), process_elapsed=None) + return TimingResult(wall_elapsed=result, process_elapsed=None) + + def done(self): + return all( + event.get_info(cl.event_info.COMMAND_EXECUTION_STATUS) + == cl.command_execution_status.COMPLETE + for event in self.events) # }}} # {{{ expansion wrangler -class UnableToCollectTimingData(UserWarning): - pass - - class SumpyExpansionWrangler(object): """Implements the :class:`boxtree.fmm.ExpansionWranglerInterface` by using :mod:`sumpy` expansions/translations. 
@@ -330,28 +345,9 @@ class SumpyExpansionWrangler(object): # }}} - def update_timing_data(self, description, timing_data, events): - if timing_data is None: - return - - if not self.queue.properties & cl.command_queue_properties.PROFILING_ENABLE: - if not self.issued_timing_data_warning: - from warnings import warn - warn( - "Profiling was not enabled in the command queue. " - "Timing data will not be collected.", - category=UnableToCollectTimingData, - stacklevel=3) - - self.issued_timing_data_warning = True - return - - timing_data.description = description - timing_data.callback = TimingFuture(events) - def form_multipoles(self, level_start_source_box_nrs, source_boxes, - src_weights, timing_data=None): + src_weights): mpoles = self.multipole_expansion_zeros() kwargs = self.extra_kwargs.copy() @@ -383,14 +379,12 @@ class SumpyExpansionWrangler(object): assert mpoles_res is mpoles_view - self.update_timing_data("form_multipoles", timing_data, events) - - return mpoles + return (mpoles, SumpyTimingFuture(self.queue, events)) def coarsen_multipoles(self, level_start_source_parent_box_nrs, source_parent_boxes, - mpoles, timing_data=None): + mpoles): tree = self.tree events = [] @@ -442,12 +436,10 @@ class SumpyExpansionWrangler(object): if events: mpoles.add_event(events[-1]) - self.update_timing_data("coarsen_multipoles", timing_data, events) - - return mpoles + return (mpoles, SumpyTimingFuture(self.queue, events)) def eval_direct(self, target_boxes, source_box_starts, - source_box_lists, src_weights, timing_data=None): + source_box_lists, src_weights): pot = self.output_zeros() kwargs = self.extra_kwargs.copy() @@ -471,14 +463,12 @@ class SumpyExpansionWrangler(object): assert pot_i is pot_res_i pot_i.add_event(evt) - self.update_timing_data("eval_direct", timing_data, events) - - return pot + return (pot, SumpyTimingFuture(self.queue, events)) def multipole_to_local(self, level_start_target_box_nrs, target_boxes, src_box_starts, src_box_lists, - mpole_exps, 
timing_data=None): + mpole_exps): local_exps = self.local_expansion_zeros() events = [] @@ -515,13 +505,10 @@ class SumpyExpansionWrangler(object): **self.kernel_extra_kwargs) events.append(evt) - self.update_timing_data("eval_direct", timing_data, events) - - return local_exps + return (local_exps, SumpyTimingFuture(self.queue, events)) def eval_multipoles(self, - target_boxes_by_source_level, source_boxes_by_level, mpole_exps, - timing_data=None): + target_boxes_by_source_level, source_boxes_by_level, mpole_exps): pot = self.output_zeros() kwargs = self.kernel_extra_kwargs.copy() @@ -568,14 +555,11 @@ class SumpyExpansionWrangler(object): for pot_i in pot: pot_i.add_event(events[-1]) - self.update_timing_data("eval_multipoles", timing_data, events) - - return pot + return (pot, SumpyTimingFuture(self.queue, events)) def form_locals(self, level_start_target_or_target_parent_box_nrs, - target_or_target_parent_boxes, starts, lists, src_weights, - timing_data=None): + target_or_target_parent_boxes, starts, lists, src_weights): local_exps = self.local_expansion_zeros() kwargs = self.extra_kwargs.copy() @@ -612,14 +596,12 @@ class SumpyExpansionWrangler(object): assert result is target_local_exps_view - self.update_timing_data("form_locals", timing_data, events) - - return local_exps + return (local_exps, SumpyTimingFuture(self.queue, events)) def refine_locals(self, level_start_target_or_target_parent_box_nrs, target_or_target_parent_boxes, - local_exps, timing_data=None): + local_exps): events = [] @@ -659,12 +641,9 @@ class SumpyExpansionWrangler(object): local_exps.add_event(evt) - self.update_timing_data("refine_locals", timing_data, events) - - return local_exps + return (local_exps, SumpyTimingFuture(self.queue, events)) - def eval_locals(self, level_start_target_box_nrs, target_boxes, local_exps, - timing_data=None): + def eval_locals(self, level_start_target_box_nrs, target_boxes, local_exps): pot = self.output_zeros() kwargs = self.kernel_extra_kwargs.copy() 
@@ -700,9 +679,7 @@ class SumpyExpansionWrangler(object): for pot_i, pot_res_i in zip(pot, pot_res): assert pot_i is pot_res_i - self.update_timing_data("eval_locals", timing_data, events) - - return pot + return (pot, SumpyTimingFuture(self.queue, events)) def finalize_potentials(self, potentials): return potentials diff --git a/test/test_fmm.py b/test/test_fmm.py index 0c4406bf..71e3f044 100644 --- a/test/test_fmm.py +++ b/test/test_fmm.py @@ -234,7 +234,7 @@ def test_sumpy_fmm(ctx_getter, knl, local_expn_class, mpole_expn_class): pconv_verifier() -def test_sumpy_fmm_timing_data(ctx_getter): +def test_sumpy_fmm_timing_data_collection(ctx_getter): logging.basicConfig(level=logging.INFO) ctx = ctx_getter() -- GitLab