diff --git a/pyopencl/check_concurrency.py b/pyopencl/check_concurrency.py
index c13cabb13023b999981665843900b1573d422a20..5c81750e91c5ca5f3667b12023752c8735114fca 100644
--- a/pyopencl/check_concurrency.py
+++ b/pyopencl/check_concurrency.py
@@ -50,15 +50,6 @@ CURRENT_BUF_ARGS = weakref.WeakKeyDictionary()
 QUEUE_TO_EVENTS = weakref.WeakKeyDictionary()
 
 
-# {{{ helpers
-
-def add_events(queue, events):
-    logger.debug('[ADD] %s: %s', queue, events)
-    QUEUE_TO_EVENTS.setdefault(queue, weakref.WeakSet()).update(events)
-
-# }}}
-
-
 # {{{ wrappers
 
 def wrapper_add_local_imports(cc, gen):
@@ -98,7 +89,7 @@ def wrapper_enqueue_nd_range_kernel(cc,
     logger.debug('enqueue_nd_range_kernel: %s', kernel.function_name)
     evt = cc.call('enqueue_nd_range_kernel')(queue, kernel,
             global_size, local_size, global_offset, wait_for, g_times_l)
-    add_events(queue, [evt])
+    QUEUE_TO_EVENTS.setdefault(queue, weakref.WeakSet()).add(evt)
 
     arg_dict = CURRENT_BUF_ARGS.get(kernel)
     if arg_dict is not None:
@@ -164,6 +155,23 @@ def wrapper_enqueue_nd_range_kernel(cc,
 
 # {{{
 
+def with_concurrency_check(func):
+    def wrapper(func, *args, **kwargs):
+        with ConcurrencyCheck():
+            return func(*args, **kwargs)
+
+    formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+
+    global logger
+    logger.setLevel(logging.DEBUG)
+    logger.addHandler(handler)
+
+    from pytools import decorator
+    return decorator.decorator(wrapper, func)
+
+
 class ConcurrencyCheck(object):
     _entered = False
 
diff --git a/test/test_wrapper.py b/test/test_wrapper.py
index f14c9398535f5069790fd058b331a341ae2896ff..d798a417b07b7b9c72f7fc1c453cabfdb8277d97 100644
--- a/test/test_wrapper.py
+++ b/test/test_wrapper.py
@@ -35,6 +35,7 @@ import pyopencl.clrandom
 from pyopencl.tools import (  # noqa
         pytest_generate_tests_for_pyopencl as pytest_generate_tests)
 from pyopencl.characterize import get_pocl_version
+from pyopencl.check_concurrency import with_concurrency_check
 
 # Are CL implementations crashy? You be the judge. :)
 try:
@@ -45,25 +46,6 @@ else:
     faulthandler.enable()
 
 
-def with_concurrency_check(func):
-    def wrapper(*args, **kwargs):
-        import pyopencl.check_concurrency as cc
-        with cc.ConcurrencyCheck():
-            func(*args, **kwargs)
-
-    import logging
-    logger = logging.getLogger('pyopencl.check_concurrency')
-
-    formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
-    handler = logging.StreamHandler()
-    handler.setFormatter(formatter)
-
-    logger.setLevel(logging.DEBUG)
-    logger.addHandler(handler)
-
-    return wrapper
-
-
 def _skip_if_pocl(plat, up_to_version, msg='unsupported by pocl'):
     if plat.vendor == "The pocl project":
         if up_to_version is None or get_pocl_version(plat) <= up_to_version: