diff --git a/pytools/__init__.py b/pytools/__init__.py index 3d03cfa3522c52d15e3e28d77dee7cc0c260edac..8dcef80b73b6292d60594979677faa3f36767672 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -138,6 +138,11 @@ Helpers for :mod:`numpy` .. autofunction:: reshaped_view +Timing data +----------- + +.. autoclass:: record_time + Log utilities ------------- @@ -2067,6 +2072,104 @@ def reshaped_view(a, newshape): # }}} +# {{{ timing utilities + +class _Timer(object): + + def __init__(self): + self.started = False + self.finished = False + + def start(self): + assert not self.started + self.started = True + + import time + if sys.version_info >= (3, 3): + self.perf_counter_start = time.perf_counter() + self.process_time_start = time.process_time() + + else: + import timeit + self.time_start = timeit.default_timer() + + def stop(self): + assert self.started + assert not self.finished + self.finished = True + + import time + if sys.version_info >= (3, 3): + self.wall_elapsed = time.perf_counter() - self.perf_counter_start + self.process_elapsed = time.process_time() - self.process_time_start + + else: + import timeit + self.wall_elapsed = timeit.default_timer() - self.time_start + self.process_elapsed = None + + +class record_time(object): # noqa: N801 + """A decorator for recording timing data for a function call. + + Timing data is saved to a :class:`dict` passed as an optional keyword + argument. The following entries are written: + + - *description* + - *wall_elapsed* + - *process_elapsed*. + + Timing data is returned in seconds. For Python versions before 3.3, + *process_elapsed* is *None*. + + Example usage:: + + >>> from time import sleep + >>> @record_time("timing_data") + ... def slow_function(n): + ... sleep(n) + ... + >>> timing_result = {} + >>> slow_function(1, timing_data=timing_result) + >>> print(timing_result) + {'description': 'slow_function', 'wall_elapsed': 1.0052917799912393, 'process_elapsed': 0.0001330000000000081} + """ # noqa: E501 + + def __init__(self, arg=None, description=None): + self.arg = arg + self.description = description + + def __call__(self, wrapped): + description = self.description or wrapped.__name__ + + from contextlib import contextmanager + + @contextmanager + def time_process(output): + timer = _Timer() + timer.start() + yield + timer.stop() + output["description"] = description + output["wall_elapsed"] = timer.wall_elapsed + output["process_elapsed"] = timer.process_elapsed + + def wrapper(*args, **kwargs): + output = kwargs.pop(self.arg, None) + if output is None: + return wrapped(*args, **kwargs) + + with time_process(output): + return wrapped(*args, **kwargs) + + from functools import update_wrapper + new_wrapper = update_wrapper(wrapper, wrapped) + + return new_wrapper + +# }}} + + # {{{ log utilities class ProcessLogger(object): @@ -2094,14 +2197,8 @@ class ProcessLogger(object): self.logger.log(self.silent_level, "%s: start", self.description) self.is_done = False - import time - if sys.version_info >= (3, 3): - self.perf_counter_start = time.perf_counter() - self.process_time_start = time.process_time() - - else: - import timeit - self.time_start = timeit.default_timer() + self.timer = _Timer() + self.timer.start() import threading self.late_start_log_thread = threading.Thread(target=self._log_start_if_long) @@ -2124,18 +2221,12 @@ class ProcessLogger(object): sleep_duration) def done(self, extra_msg=None, *extra_fmt_args): - import time - if sys.version_info >= (3, 3): - wall_elapsed = time.perf_counter() - self.perf_counter_start - process_elapsed = time.process_time() - self.process_time_start - - else: - import timeit - wall_elapsed = timeit.default_timer() - self.time_start - process_elapsed = None - + self.timer.stop() self.is_done = True + wall_elapsed = self.timer.wall_elapsed + process_elapsed = self.timer.process_elapsed + completion_level = ( self.noisy_level if wall_elapsed > self.long_threshold_seconds diff --git a/test/test_pytools.py b/test/test_pytools.py index 65514dd927c09b3b7a69e76d309e004cb5098d8b..40e1848f19a4a6b0e23ba3e6bb6a37f0a8ffe1eb 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -223,6 +223,24 @@ def test_reshaped_view(): pytools.reshaped_view(b, -1) +def test_record_time(): + from pytools import record_time + + @record_time(arg="times") + def f(a): + return a + + assert f(1) is 1 + assert f(1, times=None) is 1 + + times = {} + f(1, times=times) + + assert "description" in times + assert "wall_elapsed" in times + assert "process_elapsed" in times + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])