diff --git a/test/test_array.py b/test/test_array.py index ffb0714c870ccf554da98c353f72ef9ed98994d2..ffec2f4d21a48e917a5e2e63c78f86c6dfb4af2f 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -2180,6 +2180,64 @@ def test_dtype_conversions(ctx_factory): # }}} +# {{{ test_svm_mem_pool_with_arrays + +@pytest.mark.parametrize("use_mempool", [False, True]) +def test_arrays_with_svm_allocators(ctx_factory, use_mempool): + context = ctx_factory() + queue = cl.CommandQueue(context) + queue2 = cl.CommandQueue(context) + + from pyopencl.characterize import has_coarse_grain_buffer_svm + has_cg_svm = has_coarse_grain_buffer_svm(queue.device) + + if not has_cg_svm: + pytest.skip("Need coarse-grained SVM support for this test.") + + alloc = cl_tools.SVMAllocator(context, queue=queue) + if use_mempool: + alloc = cl_tools.SVMPool(alloc) + + def alloc2(size): + allocation = alloc(size) + allocation.bind_to_queue(queue2) + return allocation + + a_dev = cl_array.arange(queue, 2000, dtype=np.float32, allocator=alloc) + b_dev = cl_array.to_device(queue, np.arange(2000), allocator=alloc) + 4000 + + assert a_dev.allocator is alloc + assert b_dev.allocator is alloc + + assert a_dev.data._queue == queue + assert b_dev.data._queue == queue + + a_dev2 = cl_array.arange(queue2, 2000, dtype=np.float32, allocator=alloc2) + b_dev2 = cl_array.to_device(queue2, np.arange(2000), allocator=alloc2) + 4000 + + assert a_dev2.allocator is alloc2 + assert b_dev2.allocator is alloc2 + + assert a_dev2.data._queue == queue2 + assert b_dev2.data._queue == queue2 + + np.testing.assert_allclose((a_dev+b_dev).get(), (a_dev2+b_dev2).get()) + + with pytest.warns(cl_array.InconsistentOpenCLQueueWarning): + a_dev2.with_queue(queue) + + # safe to let this proceed to deallocation, since we're not + # operating on the memory + + with pytest.warns(cl_array.InconsistentOpenCLQueueWarning): + cl_array.empty(queue2, 2000, np.float32, allocator=alloc) + + # safe to let this proceed to deallocation, since we're not + # operating on the memory + +# }}} + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) @@ -2187,4 +2245,4 @@ if __name__ == "__main__": from pytest import main main([__file__]) -# vim: filetype=pyopencl:fdm=marker +# vim: fdm=marker