diff --git a/test/test_array.py b/test/test_array.py index e89d71228b2a7c68dce2ae3b01c8b383fae736be..d4ea8cd7c07d650492fd71b44bec9988658b1f5d 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -1103,6 +1103,74 @@ def test_squeeze(ctx_factory): #assert np.all(a_gpu_slice.get().ravel() == a_gpu_squeezed_slice.get().ravel()) +def test_fancy_fill(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + numpy_dest = np.zeros((4,), np.int32) + numpy_idx = np.arange(3, dtype=np.int32) + numpy_src = np.arange(8, 9, dtype=np.int32) + numpy_dest[numpy_idx] = numpy_src + + cl_dest = cl_array.zeros(queue, (4,), np.int32) + cl_idx = cl_array.arange(queue, 3, dtype=np.int32) + cl_src = cl_array.arange(queue, 8, 9, dtype=np.int32) + cl_dest[cl_idx] = cl_src + + assert np.all(numpy_dest == cl_dest.get()) + + +def test_fancy_indexing(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + numpy_dest = np.zeros((4,), np.int32) + numpy_idx = np.arange(3, 0, -1, dtype=np.int32) + numpy_src = np.arange(8, 10, dtype=np.int32) + numpy_dest[numpy_idx] = numpy_src + + cl_dest = cl_array.zeros(queue, (4,), np.int32) + cl_idx = cl_array.arange(queue, 3, 0, -1, dtype=np.int32) + cl_src = cl_array.arange(queue, 8, 10, dtype=np.int32) + cl_dest[cl_idx] = cl_src + + assert np.all(numpy_dest == cl_dest.get()) + + cl_idx[1] = 3 + cl_idx[2] = 2 + + numpy_idx[1] = 3 + numpy_idx[2] = 2 + + numpy_dest[numpy_idx] = numpy_src + cl_dest[cl_idx] = cl_src + + assert np.all(numpy_dest == cl_dest.get()) + + +def test_multi_put(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + cl_arrays = [ + cl_array.arange(queue, 0, 3, dtype=np.float32) + for i in range(1, 10) + ] + idx = cl_array.arange(queue, 0, 6, dtype=np.int32) + out_arrays = [ + cl_array.zeros(queue, (10,), np.float32) + for i in range(9) + ] + + out_compare = [np.zeros((10,), np.float32) for i in range(9)] + for i, ary in enumerate(out_compare): + ary[idx.get()] = np.arange(0, 3, dtype=np.float32) + + cl_array.multi_put(cl_arrays, idx, out=out_arrays) + + assert np.all(np.all(out_compare[i] == cl_arrays[i].get()) for i in range(10)) + + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the # tests.