From a5ce2cfcdc460a4ef8f5248d14a0ae1a93748798 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl
Date: Sat, 26 Oct 2019 21:45:57 -0500
Subject: [PATCH] Some more tweaks to the concurrency checker

Use a WeakSet to store events so we don't have to worry about them
being destroyed. The check for out-of-sync events is also stricter: we
now store the event per (buffer, op) as well as per queue, and we only
check events that are still valid in both and only for the current
buffer.
---
 pyopencl/check_concurrency.py | 99 ++++++++++++++++-------------------
 test/test_wrapper.py          | 50 ++++++++++--------
 2 files changed, 72 insertions(+), 77 deletions(-)

diff --git a/pyopencl/check_concurrency.py b/pyopencl/check_concurrency.py
index f900e739..c13cabb1 100644
--- a/pyopencl/check_concurrency.py
+++ b/pyopencl/check_concurrency.py
@@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
 OpRecord = namedtuple("OpRecord", [
     "kernel_name",
     "queue",
+    "event",
     ])
 
 
@@ -51,25 +52,9 @@ QUEUE_TO_EVENTS = weakref.WeakKeyDictionary()
 
 # {{{ helpers
 
-def remove_finished_events(events):
-    global QUEUE_TO_EVENTS
-
-    for evt in events:
-        queue = evt.get_info(cl.event_info.COMMAND_QUEUE)
-        if queue not in QUEUE_TO_EVENTS:
-            continue
-
-        logger.info('[RM] %s: %s', queue, hash(evt))
-        QUEUE_TO_EVENTS[queue].remove(hash(evt))
-
-
 def add_events(queue, events):
-    global QUEUE_TO_EVENTS
-
-    logger.debug('[ADD] %s: %s', queue, set(hash(evt) for evt in events))
-    QUEUE_TO_EVENTS.setdefault(queue, set()).update(
-        [hash(evt) for evt in events])
-
+    logger.debug('[ADD] %s: %s', queue, events)
+    QUEUE_TO_EVENTS.setdefault(queue, weakref.WeakSet()).update(events)
 
 # }}}
 
@@ -96,14 +81,6 @@ def wrapper_set_arg(cc, kernel, index, obj):
     return cc.call('set_arg')(kernel, index, obj)
 
 
-def wrapper_wait_for_events(cc, events):
-    """Wraps :func:`pyopencl.wait_for_events`"""
-
-    remove_finished_events(events)
-
-    return cc.call('wait_for_events')(events)
-
-
 def wrapper_finish(cc, queue):
     """Wraps :meth:`pyopencl.CommandQueue.finish`"""
 
@@ -119,13 +96,17 @@ def wrapper_enqueue_nd_range_kernel(cc,
     """Wraps :func:`pyopencl.enqueue_nd_range_kernel`"""
 
     logger.debug('enqueue_nd_range_kernel: %s', kernel.function_name)
+    evt = cc.call('enqueue_nd_range_kernel')(queue, kernel,
+            global_size, local_size, global_offset, wait_for, g_times_l)
+    add_events(queue, [evt])
 
     arg_dict = CURRENT_BUF_ARGS.get(kernel)
     if arg_dict is not None:
-        synced_events = set([hash(evt) for evt in wait_for]) \
-                | QUEUE_TO_EVENTS.get(queue, set())
-        logger.debug("synced events: %s", synced_events)
+        synced_events = weakref.WeakSet()
+        if wait_for is not None:
+            synced_events |= weakref.WeakSet(wait_for)
 
+        indices = list(arg_dict.keys())
         for index, buf in arg_dict.items():
             logger.debug("%s: arg %d" % (kernel.function_name, index))
 
@@ -134,6 +115,7 @@ def wrapper_enqueue_nd_range_kernel(cc,
                 continue
 
             prior_ops = BUFFER_TO_OPS.setdefault(buf, [])
+            unsynced_events = []
             for op in prior_ops:
                 prior_queue = op.queue()
                 if prior_queue is None:
@@ -141,33 +123,39 @@ def wrapper_enqueue_nd_range_kernel(cc,
                 if prior_queue.int_ptr == queue.int_ptr:
                     continue
 
-                prior_events = QUEUE_TO_EVENTS.get(prior_queue, set())
-                unsynced_events = prior_events - synced_events
-                logger.debug("%s prior events: %s", prior_queue, prior_events)
-                logger.debug("unsynced events: %s", unsynced_events)
-
-                if unsynced_events:
-                    if cc.show_traceback:
-                        print("Traceback")
-                        traceback.print_stack()
-                    from warnings import warn
-
-                    warn("\nEventsNotSynced: argument %d "
-                            "current kernel `%s` previous kernel `%s`\n"
-                            "events `%s` not found in `wait_for` "
-                            "or synced with `queue.finish()` "
-                            "or `cl.wait_for_events()`" % (
-                                index, kernel.function_name, op.kernel_name,
-                                unsynced_events),
-                            RuntimeWarning, stacklevel=5)
+                prior_event = op.event()
+                if prior_event is None:
+                    continue
+
+                prior_queue_events = QUEUE_TO_EVENTS.get(
+                        prior_queue, weakref.WeakSet())
+                if prior_event in prior_queue_events \
+                        and prior_event not in synced_events:
+                    unsynced_events.append(op.kernel_name)
+
+            logger.debug("unsynced events: %s", list(unsynced_events))
+            if unsynced_events:
+                if cc.show_traceback:
+                    print("Traceback")
+                    traceback.print_stack()
+
+                from warnings import warn
+                warn("\n[%5d] EventsNotSynced: argument `%s` in `%s`\n"
+                        "%7s current kernel `%s` previous kernels %s\n"
+                        "%7s %d events not found in `wait_for` "
+                        "or synced with `queue.finish()` "
+                        "or `cl.wait_for_events()`\n" % (
+                            cc.concurrency_issues,
+                            index, indices, " ",
+                            kernel.function_name, ", ".join(unsynced_events), " ",
+                            len(unsynced_events)),
+                        RuntimeWarning, stacklevel=5)
+                cc.concurrency_issues += 1
 
             prior_ops.append(OpRecord(
                 kernel_name=kernel.function_name,
-                queue=weakref.ref(queue),))
-
-    evt = cc.call('enqueue_nd_range_kernel')(queue, kernel,
-            global_size, local_size, global_offset, wait_for, g_times_l)
-    add_events(queue, [evt])
+                queue=weakref.ref(queue),
+                event=weakref.ref(evt),))
 
     return evt
 
@@ -207,6 +195,7 @@ class ConcurrencyCheck(object):
 
     def __enter__(self):
         self._entered = True
+        self.concurrency_issues = 0
 
         # allow monkeypatching in generated code
         self._monkey_patch(cl.invoker, 'add_local_imports')
@@ -221,8 +210,8 @@ class ConcurrencyCheck(object):
         # catch events
         self._monkey_patch(cl.Event, '__hash__',
                 wrapper=lambda x: x.int_ptr)
-        self._monkey_patch(cl, 'wait_for_events')
-        self._monkey_patch(cl.CommandQueue, 'finish')
+        self._monkey_patch(cl.CommandQueue, 'finish',
+                wrapper=lambda a: wrapper_finish(self, a))
 
         # catch kernel calls to check concurrency
         self._monkey_patch(cl, 'enqueue_nd_range_kernel')
diff --git a/test/test_wrapper.py b/test/test_wrapper.py
index 292aee47..f14c9398 100644
--- a/test/test_wrapper.py
+++ b/test/test_wrapper.py
@@ -45,6 +45,25 @@ else:
     faulthandler.enable()
 
 
+def with_concurrency_check(func):
+    def wrapper(*args, **kwargs):
+        import pyopencl.check_concurrency as cc
+        with cc.ConcurrencyCheck():
+            func(*args, **kwargs)
+
+    import logging
+    logger = logging.getLogger('pyopencl.check_concurrency')
+
+    formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+
+    logger.setLevel(logging.DEBUG)
+    logger.addHandler(handler)
+
+    return wrapper
+
+
 def _skip_if_pocl(plat, up_to_version, msg='unsupported by pocl'):
     if plat.vendor == "The pocl project":
         if up_to_version is None or get_pocl_version(plat) <= up_to_version:
@@ -1139,31 +1158,18 @@ def test_threaded_nanny_events(ctx_factory):
     t2.join()
 
 
+@with_concurrency_check
 def test_concurrency_checker(ctx_factory):
-    import logging
-    logger = logging.getLogger('pyopencl.check_concurrency')
-
-    formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
-    handler = logging.StreamHandler()
-    handler.setFormatter(formatter)
-
-    logger.setLevel(logging.DEBUG)
-    logger.addHandler(handler)
-
-    import pyopencl.check_concurrency as ccheck
-    with ccheck.ConcurrencyCheck():
-        ctx = ctx_factory()
-        queue1 = cl.CommandQueue(ctx)
-        queue2 = cl.CommandQueue(ctx)
-
-        arr1 = cl_array.zeros(queue1, (10,), np.float32)
-        arr2 = cl_array.zeros(queue2, (10,), np.float32)
-        # del arr1.events[:]
-        del arr2.events[:]
+    ctx = ctx_factory()
+    queue1 = cl.CommandQueue(ctx)
+    queue2 = cl.CommandQueue(ctx)
 
-        arr1 - arr2
+    arr1 = cl_array.zeros(queue1, (10,), np.float32)
+    arr2 = cl_array.zeros(queue2, (10,), np.float32)
+    # del arr1.events[:]
+    del arr2.events[:]
 
-        print('done')
+    arr1 - arr2
 
 
 if __name__ == "__main__":
-- 
GitLab
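
Reviewer note, not part of the patch: a minimal sketch of how the checker is meant to be
driven, mirroring test_concurrency_checker above. ConcurrencyCheck and the
concurrency_issues counter come from this patch; cl.create_some_context(), keeping the
checker instance in a variable, reading concurrency_issues after the block, and the final
print are assumptions made for illustration (the hunks shown do not say whether __enter__
returns self or whether __exit__ touches the counter).

    import numpy as np
    import pyopencl as cl
    import pyopencl.array as cl_array
    import pyopencl.check_concurrency as cc

    ctx = cl.create_some_context()
    queue1 = cl.CommandQueue(ctx)
    queue2 = cl.CommandQueue(ctx)

    checker = cc.ConcurrencyCheck()
    with checker:
        # each array is filled on its own queue, so each fill event lands in
        # a different entry of QUEUE_TO_EVENTS
        arr1 = cl_array.zeros(queue1, (10,), np.float32)
        arr2 = cl_array.zeros(queue2, (10,), np.float32)

        # drop arr2's events so the subtraction, which is enqueued on arr1's
        # queue (queue1), no longer waits on the fill that ran on queue2; the
        # wrapped enqueue_nd_range_kernel should emit the EventsNotSynced
        # warning for that argument
        del arr2.events[:]

        arr1 - arr2

    # assumed to still hold the number of warnings issued inside the block
    print("concurrency issues found:", checker.concurrency_issues)

The WeakSet/weakref.ref bookkeeping means stale records drop out on their own once the
events and queues are released, which is why the explicit remove_finished_events step and
the wait_for_events monkeypatch are no longer needed.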