diff --git a/pyopencl/check_concurrency.py b/pyopencl/check_concurrency.py index c13cabb13023b999981665843900b1573d422a20..5c81750e91c5ca5f3667b12023752c8735114fca 100644 --- a/pyopencl/check_concurrency.py +++ b/pyopencl/check_concurrency.py @@ -50,15 +50,6 @@ CURRENT_BUF_ARGS = weakref.WeakKeyDictionary() QUEUE_TO_EVENTS = weakref.WeakKeyDictionary() -# {{{ helpers - -def add_events(queue, events): - logger.debug('[ADD] %s: %s', queue, events) - QUEUE_TO_EVENTS.setdefault(queue, weakref.WeakSet()).update(events) - -# }}} - - # {{{ wrappers def wrapper_add_local_imports(cc, gen): @@ -98,7 +89,7 @@ def wrapper_enqueue_nd_range_kernel(cc, logger.debug('enqueue_nd_range_kernel: %s', kernel.function_name) evt = cc.call('enqueue_nd_range_kernel')(queue, kernel, global_size, local_size, global_offset, wait_for, g_times_l) - add_events(queue, [evt]) + QUEUE_TO_EVENTS.setdefault(queue, weakref.WeakSet()).add(evt) arg_dict = CURRENT_BUF_ARGS.get(kernel) if arg_dict is not None: @@ -164,6 +155,23 @@ def wrapper_enqueue_nd_range_kernel(cc, # {{{ +def with_concurrency_check(func): + def wrapper(func, *args, **kwargs): + with ConcurrencyCheck(): + return func(*args, **kwargs) + + formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') + handler = logging.StreamHandler() + handler.setFormatter(formatter) + + global logger + logger.setLevel(logging.DEBUG) + logger.addHandler(handler) + + from pytools import decorator + return decorator.decorator(wrapper, func) + + class ConcurrencyCheck(object): _entered = False diff --git a/test/test_wrapper.py b/test/test_wrapper.py index f14c9398535f5069790fd058b331a341ae2896ff..d798a417b07b7b9c72f7fc1c453cabfdb8277d97 100644 --- a/test/test_wrapper.py +++ b/test/test_wrapper.py @@ -35,6 +35,7 @@ import pyopencl.clrandom from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests) from pyopencl.characterize import get_pocl_version +from pyopencl.check_concurrency import with_concurrency_check # Are CL implementations crashy? You be the judge. :) try: @@ -45,25 +46,6 @@ else: faulthandler.enable() -def with_concurrency_check(func): - def wrapper(*args, **kwargs): - import pyopencl.check_concurrency as cc - with cc.ConcurrencyCheck(): - func(*args, **kwargs) - - import logging - logger = logging.getLogger('pyopencl.check_concurrency') - - formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') - handler = logging.StreamHandler() - handler.setFormatter(formatter) - - logger.setLevel(logging.DEBUG) - logger.addHandler(handler) - - return wrapper - - def _skip_if_pocl(plat, up_to_version, msg='unsupported by pocl'): if plat.vendor == "The pocl project": if up_to_version is None or get_pocl_version(plat) <= up_to_version: