diff --git a/.gitignore b/.gitignore index 172c3b70e4f141914a163a438495119d00b5eb52..a520df17f99f8343ad33639bdee13125d2136d17 100644 --- a/.gitignore +++ b/.gitignore @@ -49,7 +49,10 @@ distribute-*.tar.gz core *.sess _build +__pycache__ +*.o .ipynb_checkpoints +cscope.* # needed by jenkins env .env diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py index a8bdea15c238a114c5197ebfbd36b91aa3c78e65..bb7e0fb735dbbe84520a28cc024cb8cdf660199a 100644 --- a/pyopencl/__init__.py +++ b/pyopencl/__init__.py @@ -340,7 +340,20 @@ def _add_functionality(): def platform_repr(self): return "<pyopencl.Platform '%s' at 0x%x>" % (self.name, self.int_ptr) + def platform_get_cl_version(self): + import re + version_string = self.version + match = re.match(r"^OpenCL ([0-9]+)\.([0-9]+) .*$", version_string) + if match is None: + raise RuntimeError("platform %s returned non-conformant " + "platform version string '%s'" % + (self, version_string)) + + return int(match.group(1)), int(match.group(2)) + + Platform.__repr__ = platform_repr + Platform._get_cl_version = platform_get_cl_version # }}} @@ -367,16 +380,7 @@ def _add_functionality(): ", ".join(repr(dev) for dev in self.devices)) def context_get_cl_version(self): - import re - platform = self.devices[0].platform - plat_version_string = platform.version - match = re.match(r"^OpenCL ([0-9]+)\.([0-9]+) .*$", - plat_version_string) - if match is None: - raise RuntimeError("platform %s returned non-conformant " - "platform version string '%s'" % (platform, plat_version_string)) - - return int(match.group(1)), int(match.group(2)) + return self.devices[0].platform._get_cl_version() Context.__repr__ = context_repr from pytools import memoize_method diff --git a/src/wrapper/wrap_cl.hpp b/src/wrapper/wrap_cl.hpp index af524fd9b7bdb3d984352ce10c2fb63227ad1a88..95ebc9b0448870973aa9843b6c17944d7cc14c87 100644 --- a/src/wrapper/wrap_cl.hpp +++ b/src/wrapper/wrap_cl.hpp @@ -4236,6 +4236,13 @@ namespace pyopencl PyArray_Descr *tp_descr; if (PyArray_DescrConverter(dtype.ptr(), &tp_descr) != NPY_SUCCEED) throw py::error_already_set(); + cl_mem_flags mem_flags; + PYOPENCL_CALL_GUARDED(clGetMemObjectInfo, + (mem_obj.data(), CL_MEM_FLAGS, sizeof(mem_flags), &mem_flags, 0)); + if (!(mem_flags & CL_MEM_USE_HOST_PTR)) + throw pyopencl::error("MemoryObject.get_host_array", CL_INVALID_VALUE, + "Only MemoryObject with USE_HOST_PTR " + "is supported."); py::extract<npy_intp> shape_as_int(shape); std::vector<npy_intp> dims; diff --git a/test/test_wrapper.py b/test/test_wrapper.py index 2e3636cb11019e4d152e42b9ba1fccae4ce77790..971ca27d6f959b9adb91843acc5038e35d305505 100644 --- a/test/test_wrapper.py +++ b/test/test_wrapper.py @@ -629,10 +629,8 @@ def test_wait_for_events(ctx_factory): cl.wait_for_events([evt1, evt2]) -def test_unload_compiler(ctx_factory): - ctx = ctx_factory() - platform = ctx.devices[0].platform - if (ctx._get_cl_version() < (1, 2) or +def test_unload_compiler(platform): + if (platform._get_cl_version() < (1, 2) or cl.get_cl_header_version() < (1, 2)): from pytest import skip skip("clUnloadPlatformCompiler is only available in OpenCL 1.2") @@ -673,6 +671,94 @@ def test_enqueue_task(ctx_factory): assert la.norm(a[::-1] - b) == 0 +def test_platform_get_devices(platform): + dev_types = [cl.device_type.ACCELERATOR, cl.device_type.ALL, + cl.device_type.CPU, cl.device_type.DEFAULT, cl.device_type.GPU] + if (platform._get_cl_version() >= (1, 2) and + cl.get_cl_header_version() >= (1, 2)): + dev_types.append(cl.device_type.CUSTOM) + for dev_type in dev_types: + devs = platform.get_devices(dev_type) + if dev_type in (cl.device_type.DEFAULT, + cl.device_type.ALL, + getattr(cl.device_type, 'CUSTOM', None)): + continue + for dev in devs: + assert dev.type == dev_type + + +def test_user_event(ctx_factory): + ctx = ctx_factory() + if (ctx._get_cl_version() < (1, 1) and + cl.get_cl_header_version() < (1, 1)): + from pytest import skip + skip("UserEvent is only available in OpenCL 1.1") + + status = {} + + def event_waiter1(e, key): + e.wait() + status[key] = True + + def event_waiter2(e, key): + cl.wait_for_events([e]) + status[key] = True + + from threading import Thread + from time import sleep + evt = cl.UserEvent(ctx) + Thread(target=event_waiter1, args=(evt, 1)).start() + sleep(.05) + if status.get(1, False): + raise RuntimeError('UserEvent triggered before set_status') + evt.set_status(cl.command_execution_status.COMPLETE) + sleep(.05) + if not status.get(1, False): + raise RuntimeError('UserEvent.wait timeout') + assert evt.command_execution_status == cl.command_execution_status.COMPLETE + + evt = cl.UserEvent(ctx) + Thread(target=event_waiter2, args=(evt, 2)).start() + sleep(.05) + if status.get(2, False): + raise RuntimeError('UserEvent triggered before set_status') + evt.set_status(cl.command_execution_status.COMPLETE) + sleep(.05) + if not status.get(2, False): + raise RuntimeError('cl.wait_for_events timeout on UserEvent') + assert evt.command_execution_status == cl.command_execution_status.COMPLETE + + +def test_buffer_get_host_array(ctx_factory): + ctx = ctx_factory() + mf = cl.mem_flags + + host_buf = np.random.rand(25).astype(np.float32) + buf = cl.Buffer(ctx, mf.READ_WRITE | mf.USE_HOST_PTR, hostbuf=host_buf) + host_buf2 = buf.get_host_array(25, np.float32) + assert (host_buf == host_buf2).all() + assert (host_buf.__array_interface__['data'][0] == + host_buf.__array_interface__['data'][0]) + assert host_buf2.base is buf + + buf = cl.Buffer(ctx, mf.READ_WRITE | mf.ALLOC_HOST_PTR, size=100) + try: + host_buf2 = buf.get_host_array(25, np.float32) + assert False, ("MemoryObject.get_host_array should not accept buffer " + "without USE_HOST_PTR") + except cl.LogicError: + pass + + host_buf = np.random.rand(25).astype(np.float32) + buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=host_buf) + try: + host_buf2 = buf.get_host_array(25, np.float32) + assert False, ("MemoryObject.get_host_array should not accept buffer " + "without USE_HOST_PTR") + except cl.LogicError: + pass + + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the tests. import pyopencl # noqa