From d18375016671751d039bf3a8ecd37504193e4fd7 Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Mon, 9 Jul 2018 23:02:48 -0500
Subject: [PATCH] Update for boxtree changes: wrangler methods return timing futures

---
 sumpy/fmm.py     | 107 +++++++++++++++++++----------------------------
 test/test_fmm.py |   2 +-
 2 files changed, 43 insertions(+), 66 deletions(-)

diff --git a/sumpy/fmm.py b/sumpy/fmm.py
index cd3ed0af..17d52eed 100644
--- a/sumpy/fmm.py
+++ b/sumpy/fmm.py
@@ -150,13 +150,29 @@ class SumpyExpansionWranglerCodeContainer(object):
 _SECONDS_PER_NANOSECOND = 1e-9
 
 
-class TimingFuture(object):
+class UnableToCollectTimingData(UserWarning):
+    """Warning issued when timing data cannot be collected."""
+
+
+class SumpyTimingFuture(object):
 
-    def __init__(self, events):
+    def __init__(self, queue, events):
+        self.queue = queue
         self.events = events
 
     @memoize_method
-    def get(self):
+    def result(self):
+        from boxtree.fmm import TimingResult
+
+        if not self.queue.properties & cl.command_queue_properties.PROFILING_ENABLE:
+            from warnings import warn
+            warn(
+                    "Profiling was not enabled in the command queue. "
+                    "Timing data will not be collected.",
+                    category=UnableToCollectTimingData,
+                    stacklevel=3)
+            return TimingResult(wall_elapsed=None, process_elapsed=None)
+
         pyopencl.wait_for_events(self.events)
 
         result = 0
@@ -164,21 +180,20 @@ class TimingFuture(object):
             result += (
                     (event.profile.end - event.profile.start)
                     * _SECONDS_PER_NANOSECOND)
-        return result
 
-    def __call__(self):
-        from boxtree.fmm import TimingResult
-        return TimingResult(wall_elapsed=self.get(), process_elapsed=None)
+        return TimingResult(wall_elapsed=result, process_elapsed=None)
+
+    def done(self):
+        """Return True if every event has finished executing."""
+        complete = cl.command_execution_status.COMPLETE
+        return all(event.get_info(cl.event_info.COMMAND_EXECUTION_STATUS)
+                == complete for event in self.events)
 
 # }}}
 
 
 # {{{ expansion wrangler
 
-class UnableToCollectTimingData(UserWarning):
-    pass
-
-
 class SumpyExpansionWrangler(object):
     """Implements the :class:`boxtree.fmm.ExpansionWranglerInterface`
     by using :mod:`sumpy` expansions/translations.
@@ -330,28 +345,9 @@ class SumpyExpansionWrangler(object):
 
     # }}}
 
-    def update_timing_data(self, description, timing_data, events):
-        if timing_data is None:
-            return
-
-        if not self.queue.properties & cl.command_queue_properties.PROFILING_ENABLE:
-            if not self.issued_timing_data_warning:
-                from warnings import warn
-                warn(
-                        "Profiling was not enabled in the command queue. "
-                        "Timing data will not be collected.",
-                        category=UnableToCollectTimingData,
-                        stacklevel=3)
-
-            self.issued_timing_data_warning = True
-            return
-
-        timing_data.description = description
-        timing_data.callback = TimingFuture(events)
-
     def form_multipoles(self,
             level_start_source_box_nrs, source_boxes,
-            src_weights, timing_data=None):
+            src_weights):
         mpoles = self.multipole_expansion_zeros()
 
         kwargs = self.extra_kwargs.copy()
@@ -383,14 +379,12 @@ class SumpyExpansionWrangler(object):
 
             assert mpoles_res is mpoles_view
 
-        self.update_timing_data("form_multipoles", timing_data, events)
-
-        return mpoles
+        return (mpoles, SumpyTimingFuture(self.queue, events))
 
     def coarsen_multipoles(self,
             level_start_source_parent_box_nrs,
             source_parent_boxes,
-            mpoles, timing_data=None):
+            mpoles):
         tree = self.tree
 
         events = []
@@ -442,12 +436,10 @@ class SumpyExpansionWrangler(object):
         if events:
             mpoles.add_event(events[-1])
 
-        self.update_timing_data("coarsen_multipoles", timing_data, events)
-
-        return mpoles
+        return (mpoles, SumpyTimingFuture(self.queue, events))
 
     def eval_direct(self, target_boxes, source_box_starts,
-            source_box_lists, src_weights, timing_data=None):
+            source_box_lists, src_weights):
         pot = self.output_zeros()
 
         kwargs = self.extra_kwargs.copy()
@@ -471,14 +463,12 @@ class SumpyExpansionWrangler(object):
             assert pot_i is pot_res_i
             pot_i.add_event(evt)
 
-        self.update_timing_data("eval_direct", timing_data, events)
-
-        return pot
+        return (pot, SumpyTimingFuture(self.queue, events))
 
     def multipole_to_local(self,
             level_start_target_box_nrs,
             target_boxes, src_box_starts, src_box_lists,
-            mpole_exps, timing_data=None):
+            mpole_exps):
         local_exps = self.local_expansion_zeros()
 
         events = []
@@ -515,13 +505,10 @@ class SumpyExpansionWrangler(object):
                     **self.kernel_extra_kwargs)
             events.append(evt)
 
-        self.update_timing_data("eval_direct", timing_data, events)
-
-        return local_exps
+        return (local_exps, SumpyTimingFuture(self.queue, events))
 
     def eval_multipoles(self,
-            target_boxes_by_source_level, source_boxes_by_level, mpole_exps,
-            timing_data=None):
+            target_boxes_by_source_level, source_boxes_by_level, mpole_exps):
         pot = self.output_zeros()
 
         kwargs = self.kernel_extra_kwargs.copy()
@@ -568,14 +555,11 @@ class SumpyExpansionWrangler(object):
             for pot_i in pot:
                 pot_i.add_event(events[-1])
 
-        self.update_timing_data("eval_multipoles", timing_data, events)
-
-        return pot
+        return (pot, SumpyTimingFuture(self.queue, events))
 
     def form_locals(self,
             level_start_target_or_target_parent_box_nrs,
-            target_or_target_parent_boxes, starts, lists, src_weights,
-            timing_data=None):
+            target_or_target_parent_boxes, starts, lists, src_weights):
         local_exps = self.local_expansion_zeros()
 
         kwargs = self.extra_kwargs.copy()
@@ -612,14 +596,12 @@ class SumpyExpansionWrangler(object):
 
             assert result is target_local_exps_view
 
-        self.update_timing_data("form_locals", timing_data, events)
-
-        return local_exps
+        return (local_exps, SumpyTimingFuture(self.queue, events))
 
     def refine_locals(self,
             level_start_target_or_target_parent_box_nrs,
             target_or_target_parent_boxes,
-            local_exps, timing_data=None):
+            local_exps):
 
         events = []
 
@@ -659,12 +641,9 @@ class SumpyExpansionWrangler(object):
 
         local_exps.add_event(evt)
 
-        self.update_timing_data("refine_locals", timing_data, events)
-
-        return local_exps
+        return (local_exps, SumpyTimingFuture(self.queue, events))
 
-    def eval_locals(self, level_start_target_box_nrs, target_boxes, local_exps,
-            timing_data=None):
+    def eval_locals(self, level_start_target_box_nrs, target_boxes, local_exps):
         pot = self.output_zeros()
 
         kwargs = self.kernel_extra_kwargs.copy()
@@ -700,9 +679,7 @@ class SumpyExpansionWrangler(object):
             for pot_i, pot_res_i in zip(pot, pot_res):
                 assert pot_i is pot_res_i
 
-        self.update_timing_data("eval_locals", timing_data, events)
-
-        return pot
+        return (pot, SumpyTimingFuture(self.queue, events))
 
     def finalize_potentials(self, potentials):
         return potentials
diff --git a/test/test_fmm.py b/test/test_fmm.py
index 0c4406bf..71e3f044 100644
--- a/test/test_fmm.py
+++ b/test/test_fmm.py
@@ -234,7 +234,7 @@ def test_sumpy_fmm(ctx_getter, knl, local_expn_class, mpole_expn_class):
     pconv_verifier()
 
 
-def test_sumpy_fmm_timing_data(ctx_getter):
+def test_sumpy_fmm_timing_data_collection(ctx_getter):
     logging.basicConfig(level=logging.INFO)
 
     ctx = ctx_getter()
-- 
GitLab