diff --git a/pyopencl/check_concurrency.py b/pyopencl/check_concurrency.py index f900e7399d1bde476056cd5e0cc8eeb5a2d45e53..c13cabb13023b999981665843900b1573d422a20 100644 --- a/pyopencl/check_concurrency.py +++ b/pyopencl/check_concurrency.py @@ -35,6 +35,7 @@ logger = logging.getLogger(__name__) OpRecord = namedtuple("OpRecord", [ "kernel_name", "queue", + "event", ]) @@ -51,25 +52,9 @@ QUEUE_TO_EVENTS = weakref.WeakKeyDictionary() # {{{ helpers -def remove_finished_events(events): - global QUEUE_TO_EVENTS - - for evt in events: - queue = evt.get_info(cl.event_info.COMMAND_QUEUE) - if queue not in QUEUE_TO_EVENTS: - continue - - logger.info('[RM] %s: %s', queue, hash(evt)) - QUEUE_TO_EVENTS[queue].remove(hash(evt)) - - def add_events(queue, events): - global QUEUE_TO_EVENTS - - logger.debug('[ADD] %s: %s', queue, set(hash(evt) for evt in events)) - QUEUE_TO_EVENTS.setdefault(queue, set()).update( - [hash(evt) for evt in events]) - + logger.debug('[ADD] %s: %s', queue, events) + QUEUE_TO_EVENTS.setdefault(queue, weakref.WeakSet()).update(events) # }}} @@ -96,14 +81,6 @@ def wrapper_set_arg(cc, kernel, index, obj): return cc.call('set_arg')(kernel, index, obj) -def wrapper_wait_for_events(cc, events): - """Wraps :func:`pyopencl.wait_for_events`""" - - remove_finished_events(events) - - return cc.call('wait_for_events')(events) - - def wrapper_finish(cc, queue): """Wraps :meth:`pyopencl.CommandQueue.finish`""" @@ -119,13 +96,17 @@ def wrapper_enqueue_nd_range_kernel(cc, """Wraps :func:`pyopencl.enqueue_nd_range_kernel`""" logger.debug('enqueue_nd_range_kernel: %s', kernel.function_name) + evt = cc.call('enqueue_nd_range_kernel')(queue, kernel, + global_size, local_size, global_offset, wait_for, g_times_l) + add_events(queue, [evt]) arg_dict = CURRENT_BUF_ARGS.get(kernel) if arg_dict is not None: - synced_events = set([hash(evt) for evt in wait_for]) \ - | QUEUE_TO_EVENTS.get(queue, set()) - logger.debug("synced events: %s", synced_events) + synced_events = weakref.WeakSet() + if wait_for is not None: + synced_events |= weakref.WeakSet(wait_for) + indices = list(arg_dict.keys()) for index, buf in arg_dict.items(): logger.debug("%s: arg %d" % (kernel.function_name, index)) @@ -134,6 +115,7 @@ def wrapper_enqueue_nd_range_kernel(cc, continue prior_ops = BUFFER_TO_OPS.setdefault(buf, []) + unsynced_events = [] for op in prior_ops: prior_queue = op.queue() if prior_queue is None: @@ -141,33 +123,39 @@ def wrapper_enqueue_nd_range_kernel(cc, if prior_queue.int_ptr == queue.int_ptr: continue - prior_events = QUEUE_TO_EVENTS.get(prior_queue, set()) - unsynced_events = prior_events - synced_events - logger.debug("%s prior events: %s", prior_queue, prior_events) - logger.debug("unsynced events: %s", unsynced_events) - - if unsynced_events: - if cc.show_traceback: - print("Traceback") - traceback.print_stack() - from warnings import warn - - warn("\nEventsNotSynced: argument %d " - "current kernel `%s` previous kernel `%s`\n" - "events `%s` not found in `wait_for` " - "or synced with `queue.finish()` " - "or `cl.wait_for_events()`" % ( - index, kernel.function_name, op.kernel_name, - unsynced_events), - RuntimeWarning, stacklevel=5) + prior_event = op.event() + if prior_event is None: + continue + + prior_queue_events = QUEUE_TO_EVENTS.get( + prior_queue, weakref.WeakSet()) + if prior_event in prior_queue_events \ + and prior_event not in synced_events: + unsynced_events.append(op.kernel_name) + + logger.debug("unsynced events: %s", list(unsynced_events)) + if unsynced_events: + if cc.show_traceback: + print("Traceback") + traceback.print_stack() + + from warnings import warn + warn("\n[%5d] EventsNotSynced: argument `%s` in `%s`\n" + "%7s current kernel `%s` previous kernels %s\n" + "%7s %d events not found in `wait_for` " + "or synced with `queue.finish()` " + "or `cl.wait_for_events()`\n" % ( + cc.concurrency_issues, + index, indices, " ", + kernel.function_name, ", ".join(unsynced_events), " ", + len(unsynced_events)), + RuntimeWarning, stacklevel=5) + cc.concurrency_issues += 1 prior_ops.append(OpRecord( kernel_name=kernel.function_name, - queue=weakref.ref(queue),)) - - evt = cc.call('enqueue_nd_range_kernel')(queue, kernel, - global_size, local_size, global_offset, wait_for, g_times_l) - add_events(queue, [evt]) + queue=weakref.ref(queue), + event=weakref.ref(evt),)) return evt @@ -207,6 +195,7 @@ class ConcurrencyCheck(object): def __enter__(self): self._entered = True + self.concurrency_issues = 0 # allow monkeypatching in generated code self._monkey_patch(cl.invoker, 'add_local_imports') @@ -221,8 +210,8 @@ class ConcurrencyCheck(object): # catch events self._monkey_patch(cl.Event, '__hash__', wrapper=lambda x: x.int_ptr) - self._monkey_patch(cl, 'wait_for_events') - self._monkey_patch(cl.CommandQueue, 'finish') + self._monkey_patch(cl.CommandQueue, 'finish', + wrapper=lambda a: wrapper_finish(self, a)) # catch kernel calls to check concurrency self._monkey_patch(cl, 'enqueue_nd_range_kernel') diff --git a/test/test_wrapper.py b/test/test_wrapper.py index 292aee47cf5bef139594ba3c3bd0ffb4c0ad8e04..f14c9398535f5069790fd058b331a341ae2896ff 100644 --- a/test/test_wrapper.py +++ b/test/test_wrapper.py @@ -45,6 +45,25 @@ else: faulthandler.enable() +def with_concurrency_check(func): + def wrapper(*args, **kwargs): + import pyopencl.check_concurrency as cc + with cc.ConcurrencyCheck(): + func(*args, **kwargs) + + import logging + logger = logging.getLogger('pyopencl.check_concurrency') + + formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') + handler = logging.StreamHandler() + handler.setFormatter(formatter) + + logger.setLevel(logging.DEBUG) + logger.addHandler(handler) + + return wrapper + + def _skip_if_pocl(plat, up_to_version, msg='unsupported by pocl'): if plat.vendor == "The pocl project": if up_to_version is None or get_pocl_version(plat) <= up_to_version: @@ -1139,31 +1158,18 @@ def test_threaded_nanny_events(ctx_factory): t2.join() +@with_concurrency_check def test_concurrency_checker(ctx_factory): - import logging - logger = logging.getLogger('pyopencl.check_concurrency') - - formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') - handler = logging.StreamHandler() - handler.setFormatter(formatter) - - logger.setLevel(logging.DEBUG) - logger.addHandler(handler) - - import pyopencl.check_concurrency as ccheck - with ccheck.ConcurrencyCheck(): - ctx = ctx_factory() - queue1 = cl.CommandQueue(ctx) - queue2 = cl.CommandQueue(ctx) - - arr1 = cl_array.zeros(queue1, (10,), np.float32) - arr2 = cl_array.zeros(queue2, (10,), np.float32) - # del arr1.events[:] - del arr2.events[:] + ctx = ctx_factory() + queue1 = cl.CommandQueue(ctx) + queue2 = cl.CommandQueue(ctx) - arr1 - arr2 + arr1 = cl_array.zeros(queue1, (10,), np.float32) + arr2 = cl_array.zeros(queue2, (10,), np.float32) + # del arr1.events[:] + del arr2.events[:] - print('done') + arr1 - arr2 if __name__ == "__main__":