diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py index a84b5d1197c786ecb4ca3b04bb1ddad3e2b72f5f..abbf29c5b726fd808c5a843335c9b2e7e0eb3292 100644 --- a/pyopencl/__init__.py +++ b/pyopencl/__init__.py @@ -1735,6 +1735,11 @@ class SVMMap: # {{{ enqueue_copy +_IMAGE_MEM_OBJ_TYPES = [mem_object_type.IMAGE2D, mem_object_type.IMAGE3D] +if get_cl_header_version() >= (1, 2): + _IMAGE_MEM_OBJ_TYPES.append(mem_object_type.IMAGE2D_ARRAY) + + def enqueue_copy(queue, dest, src, **kwargs): """Copy from :class:`Image`, :class:`Buffer` or the host to :class:`Image`, :class:`Buffer` or the host. (Note: host-to-host @@ -1903,7 +1908,7 @@ def enqueue_copy(queue, dest, src, **kwargs): return _cl._enqueue_copy_buffer(queue, src, dest, **kwargs) # }}} - elif src.type in [mem_object_type.IMAGE2D, mem_object_type.IMAGE3D]: + elif src.type in _IMAGE_MEM_OBJ_TYPES: return _cl._enqueue_copy_image_to_buffer( queue, src, dest, **kwargs) else: @@ -1931,14 +1936,14 @@ def enqueue_copy(queue, dest, src, **kwargs): # }}} - elif dest.type in [mem_object_type.IMAGE2D, mem_object_type.IMAGE3D]: + elif dest.type in _IMAGE_MEM_OBJ_TYPES: # {{{ ... -> image if isinstance(src, MemoryObjectHolder): if src.type == mem_object_type.BUFFER: return _cl._enqueue_copy_buffer_to_image( queue, src, dest, **kwargs) - elif src.type in [mem_object_type.IMAGE2D, mem_object_type.IMAGE3D]: + elif src.type in _IMAGE_MEM_OBJ_TYPES: return _cl._enqueue_copy_image(queue, src, dest, **kwargs) else: raise ValueError("invalid src mem object type") @@ -2002,7 +2007,7 @@ def enqueue_copy(queue, dest, src, **kwargs): return _cl._enqueue_read_buffer(queue, src, dest, **kwargs) - elif src.type in [mem_object_type.IMAGE2D, mem_object_type.IMAGE3D]: + elif src.type in _IMAGE_MEM_OBJ_TYPES: origin = kwargs.pop("origin") region = kwargs.pop("region") diff --git a/test/test_wrapper.py b/test/test_wrapper.py index c0ecdcaace8e3b313939905783bc0753eb26d06f..a7f1402de6d4986a40bbe78f0e881853f934a670 100644 --- a/test/test_wrapper.py +++ b/test/test_wrapper.py @@ -1468,6 +1468,61 @@ def test_capture_call(ctx_factory): # }}} +# {{{ test_enqueue_copy_array + +def test_enqueue_copy_array(ctx_factory): + # https://github.com/inducer/pyopencl/issues/618 + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + if ctx._get_cl_version() < (1, 2) or cl.get_cl_header_version() < (1, 2): + pytest.skip("requires CL 1.2") + + if not queue.device.image_support: + pytest.skip("device has no image support") + + image_format = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT) + flags = cl.mem_flags.READ_ONLY + image = np.ascontiguousarray(np.zeros((128, 128, 4), np.float32)) + image_cl = cl.Image(ctx, flags, image_format, + shape=(image.shape[1], image.shape[0], 1), is_array=True) + cl.enqueue_copy(queue, dest=image, src=image_cl, + origin=(0, 0, 0), region=(image.shape[1], image.shape[0], 1)) + + +def test_enqueue_copy_array_2(ctx_factory): + # https://github.com/inducer/pyopencl/issues/618 + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + if ctx._get_cl_version() < (1, 2) or cl.get_cl_header_version() < (1, 2): + pytest.skip("requires CL 1.2") + + if not queue.device.image_support: + pytest.skip("device has no image support") + + image_format = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT) + image = np.ascontiguousarray(np.zeros((128, 128, 4), np.float32)) + image_shape = (image.shape[1], image.shape[0]) + array_shape = (*image_shape, 1) + cl.Image(ctx, cl.mem_flags.READ_ONLY, + image_format, shape=image_shape) + image_array_cl = cl.Image(ctx, cl.mem_flags.READ_ONLY, + image_format, shape=array_shape, is_array=True) + image2_array_cl = cl.Image(ctx, cl.mem_flags.WRITE_ONLY, + image_format, shape=array_shape, is_array=True) + buffer_cl = cl.Buffer(ctx, cl.mem_flags.WRITE_ONLY, size=image.nbytes) + + cl._cl._enqueue_copy_image( + queue, src=image_array_cl, dest=image2_array_cl, src_origin=(0, 0, 0), + dest_origin=(0, 0, 0), region=array_shape) + cl._cl._enqueue_copy_image_to_buffer( + queue, src=image_array_cl, dest=buffer_cl, offset=0, origin=(0, 0, 0), + region=array_shape) + +# }}} + + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the tests. import pyopencl # noqa