diff --git a/pyopencl/ipython_ext.py b/pyopencl/ipython_ext.py index 2d7d42d7c78106490d556bbd5bfb67b8d80decdf..a61d7be76c1c0499edb349e82820c6e29591aed6 100644 --- a/pyopencl/ipython_ext.py +++ b/pyopencl/ipython_ext.py @@ -1,6 +1,6 @@ from __future__ import division -from IPython.core.magic import (magics_class, Magics, cell_magic) +from IPython.core.magic import (magics_class, Magics, cell_magic, line_magic) import pyopencl as cl @@ -13,8 +13,9 @@ def _try_to_utf8(text): @magics_class class PyOpenCLMagics(Magics): - @cell_magic - def cl_kernel(self, line, cell): + def _run_kernel(self, kernel, options): + kernel = _try_to_utf8(kernel) + options = _try_to_utf8(options).strip() try: ctx = self.shell.user_ns["cl_ctx"] except KeyError: @@ -33,13 +34,49 @@ class PyOpenCLMagics(Magics): raise RuntimeError("unable to locate cl context, which must be " "present in namespace as 'cl_ctx' or 'ctx'") - opts, args = self.parse_options(line,'o:') - build_options = opts.get('o', '') - prg = cl.Program(ctx, _try_to_utf8(cell)).build(options=_try_to_utf8(build_options).strip()) + prg = cl.Program(ctx, kernel).build(options=options) for knl in prg.all_kernels(): self.shell.user_ns[knl.function_name] = knl + @cell_magic + def cl_kernel(self, line, cell): + kernel = cell + + opts, args = self.parse_options(line,'o:') + build_options = opts.get('o', '') + + self._run_kernel(kernel, build_options) + + + def _load_kernel_and_options(self, line): + opts, args = self.parse_options(line,'o:f:') + + build_options = opts.get('o') + kernel = self.shell.find_user_code(opts.get('f') or args) + + return kernel, build_options + + + @line_magic + def cl_run_kernel(self, line): + kernel, build_options = self._load_kernel_and_options(line) + self._run_kernel(kernel, build_options) + + + @line_magic + def cl_load_kernel(self, line): + kernel, build_options = self._load_kernel_and_options(line) + header = "%%cl_kernel" + + if build_options: + header = "%s %s" % (header, build_options) + + content = "%s\n\n%s" % (header, kernel) + + self.shell.set_next_input(content) + + def load_ipython_extension(ip): ip.register_magics(PyOpenCLMagics)