diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py index ee3a7154c397fca7ec82df00adcb5976d4c78243..b4cd69efdb5b5bf21ca3f2064ef44e1e98fc5650 100644 --- a/pyopencl/__init__.py +++ b/pyopencl/__init__.py @@ -37,6 +37,14 @@ def compiler_output(text): +# {{{ Kernel + +class Kernel(_cl._Kernel): + def __init__(self, prg, name): + _cl._Kernel.__init__(self, prg._get_prg(), name) + +# }}} + # {{{ Program (including caching support) class Program(object): @@ -89,7 +97,7 @@ class Program(object): def __getattr__(self, attr): try: - knl = Kernel(self._get_prg(), attr) + knl = Kernel(self, attr) # Nvidia does not raise errors even for invalid names, # but this will give an error if the kernel is invalid. knl.num_args @@ -261,7 +269,7 @@ def _add_functionality(): (_cl._ImageBase.get_image_info, _cl.image_info), Program: (Program.get_info, _cl.program_info), - _cl.Kernel: + Kernel: (Kernel.get_info, _cl.kernel_info), _cl.Sampler: (Sampler.get_info, _cl.sampler_info), diff --git a/src/wrapper/wrap_cl_part_2.cpp b/src/wrapper/wrap_cl_part_2.cpp index 8fd8a63299e802827a356d105aea08e29c281988..f06cc0a1b2933647928bd371672f58df576eeb04 100644 --- a/src/wrapper/wrap_cl_part_2.cpp +++ b/src/wrapper/wrap_cl_part_2.cpp @@ -239,7 +239,7 @@ void pyopencl_expose_part_2() { typedef kernel cls; - py::class_<cls, boost::noncopyable>("Kernel", + py::class_<cls, boost::noncopyable>("_Kernel", py::init<const program &, std::string const &>()) .DEF_SIMPLE_METHOD(get_info) .DEF_SIMPLE_METHOD(get_work_group_info) diff --git a/test/test_wrapper.py b/test/test_wrapper.py index 1fb3487f0d355d96f1b15d245b8ba966c771f3fb..fbde65d6140f22b0a936eec160068ee35a147113 100644 --- a/test/test_wrapper.py +++ b/test/test_wrapper.py @@ -204,8 +204,9 @@ class TestCL: mf = cl.mem_flags a_buf = cl.Buffer(context, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=a) + knl = cl.Kernel(prg, "mult") try: - prg.mult(queue, a.shape, None, a_buf, 2, 3) + knl(queue, a.shape, None, a_buf, 2, 3) assert False, "PyOpenCL should not accept bare Python types as arguments" except cl.LogicError: pass