From 73343d5db3ab615e94a09ddb2a26c3b30e4cd21e Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sat, 26 Oct 2019 21:57:58 -0500 Subject: [PATCH] add decorator --- pyopencl/check_concurrency.py | 28 ++++++++++++++++++---------- test/test_wrapper.py | 20 +------------------- 2 files changed, 19 insertions(+), 29 deletions(-) diff --git a/pyopencl/check_concurrency.py b/pyopencl/check_concurrency.py index c13cabb1..5c81750e 100644 --- a/pyopencl/check_concurrency.py +++ b/pyopencl/check_concurrency.py @@ -50,15 +50,6 @@ CURRENT_BUF_ARGS = weakref.WeakKeyDictionary() QUEUE_TO_EVENTS = weakref.WeakKeyDictionary() -# {{{ helpers - -def add_events(queue, events): - logger.debug('[ADD] %s: %s', queue, events) - QUEUE_TO_EVENTS.setdefault(queue, weakref.WeakSet()).update(events) - -# }}} - - # {{{ wrappers def wrapper_add_local_imports(cc, gen): @@ -98,7 +89,7 @@ def wrapper_enqueue_nd_range_kernel(cc, logger.debug('enqueue_nd_range_kernel: %s', kernel.function_name) evt = cc.call('enqueue_nd_range_kernel')(queue, kernel, global_size, local_size, global_offset, wait_for, g_times_l) - add_events(queue, [evt]) + QUEUE_TO_EVENTS.setdefault(queue, weakref.WeakSet()).add(evt) arg_dict = CURRENT_BUF_ARGS.get(kernel) if arg_dict is not None: @@ -164,6 +155,23 @@ def wrapper_enqueue_nd_range_kernel(cc, # {{{ +def with_concurrency_check(func): + def wrapper(func, *args, **kwargs): + with ConcurrencyCheck(): + return func(*args, **kwargs) + + formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') + handler = logging.StreamHandler() + handler.setFormatter(formatter) + + global logger + logger.setLevel(logging.DEBUG) + logger.addHandler(handler) + + from pytools import decorator + return decorator.decorator(wrapper, func) + + class ConcurrencyCheck(object): _entered = False diff --git a/test/test_wrapper.py b/test/test_wrapper.py index f14c9398..d798a417 100644 --- a/test/test_wrapper.py +++ b/test/test_wrapper.py @@ -35,6 +35,7 @@ import pyopencl.clrandom from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests) from pyopencl.characterize import get_pocl_version +from pyopencl.check_concurrency import with_concurrency_check # Are CL implementations crashy? You be the judge. :) try: @@ -45,25 +46,6 @@ else: faulthandler.enable() -def with_concurrency_check(func): - def wrapper(*args, **kwargs): - import pyopencl.check_concurrency as cc - with cc.ConcurrencyCheck(): - func(*args, **kwargs) - - import logging - logger = logging.getLogger('pyopencl.check_concurrency') - - formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') - handler = logging.StreamHandler() - handler.setFormatter(formatter) - - logger.setLevel(logging.DEBUG) - logger.addHandler(handler) - - return wrapper - - def _skip_if_pocl(plat, up_to_version, msg='unsupported by pocl'): if plat.vendor == "The pocl project": if up_to_version is None or get_pocl_version(plat) <= up_to_version: -- GitLab