From 85531860e803c560dccaa48ce5b6f00cfe99001b Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 27 Nov 2014 23:41:04 -0600 Subject: [PATCH] Fix non-contiguous reshape --- pyopencl/array.py | 98 +++++++++++++++++++++++++++++++++++++++-- pyopencl/ipython_ext.py | 17 ++++--- test/test_array.py | 18 ++++++++ 3 files changed, 121 insertions(+), 12 deletions(-) diff --git a/pyopencl/array.py b/pyopencl/array.py index d00253a4..71ce3d8c 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -1288,20 +1288,112 @@ class Array(object): raise TypeError("unexpected keyword arguments: %s" % kwargs.keys()) + if order not in "CF": + raise ValueError("order must be either 'C' or 'F'") + # TODO: add more error-checking, perhaps + if isinstance(shape[0], tuple) or isinstance(shape[0], list): shape = tuple(shape[0]) + if any(s < 0 for s in shape): + raise NotImplementedError("negative/automatic shapes not supported") + if shape == self.shape: - return self + return self._new_with_changes( + data=self.base_data, offset=self.offset, shape=shape, + strides=self.strides) - size = reduce(lambda x, y: x * y, shape, 1) + import operator + size = reduce(operator.mul, shape, 1) if size != self.size: raise ValueError("total size of new array must be unchanged") + # {{{ determine reshaped strides + + # copied and translated from + # https://github.com/numpy/numpy/blob/4083883228d61a3b571dec640185b5a5d983bf59/numpy/core/src/multiarray/shape.c # noqa + + newdims = shape + newnd = len(newdims) + + # Remove axes with dimension 1 from the old array. They have no effect + # but would need special cases since their strides do not matter. + + olddims = [] + oldstrides = [] + for oi in range(len(self.shape)): + s = self.shape[oi] + if s != 1: + olddims.append(s) + oldstrides.append(self.strides[oi]) + + oldnd = len(olddims) + + newstrides = [-1]*len(newdims) + + # oi to oj and ni to nj give the axis ranges currently worked with + oi = 0 + oj = 1 + ni = 0 + nj = 1 + while ni < newnd and oi < oldnd: + np = newdims[ni] + op = olddims[oi] + + while np != op: + if np < op: + # Misses trailing 1s, these are handled later + np *= newdims[nj] + nj += 1 + else: + op *= olddims[oj] + oj += 1 + + # Check whether the original axes can be combined + for ok in range(oi, oj-1): + if order == "F": + if oldstrides[ok+1] != olddims[ok]*oldstrides[ok]: + raise ValueError("cannot reshape without copy") + else: + # C order + if (oldstrides[ok] != olddims[ok+1]*oldstrides[ok+1]): + raise ValueError("cannot reshape without copy") + + # Calculate new strides for all axes currently worked with + if order == "F": + newstrides[ni] = oldstrides[oi] + for nk in xrange(ni+1, nj): + newstrides[nk] = newstrides[nk - 1]*newdims[nk - 1] + else: + # C order + newstrides[nj - 1] = oldstrides[oj - 1] + for nk in range(nj-1, ni, -1): + newstrides[nk - 1] = newstrides[nk]*newdims[nk] + + ni = nj + nj += 1 + + oi = oj + oj += 1 + + # Set strides corresponding to trailing 1s of the new shape. + if ni >= 1: + last_stride = newstrides[ni - 1] + else: + last_stride = self.dtype.itemsize + + if order == "F": + last_stride *= newdims[ni - 1] + + for nk in range(ni, len(shape)): + newstrides[nk] = last_stride + + # }}} + return self._new_with_changes( data=self.base_data, offset=self.offset, shape=shape, - strides=_make_strides(self.dtype.itemsize, shape, order)) + strides=tuple(newstrides)) def ravel(self): """Returns flattened array containing the same data.""" diff --git a/pyopencl/ipython_ext.py b/pyopencl/ipython_ext.py index 81fbcdf8..bedbf77f 100644 --- a/pyopencl/ipython_ext.py +++ b/pyopencl/ipython_ext.py @@ -3,6 +3,7 @@ from __future__ import division from IPython.core.magic import (magics_class, Magics, cell_magic, line_magic) import pyopencl as cl +import sys def _try_to_utf8(text): @@ -14,8 +15,10 @@ def _try_to_utf8(text): @magics_class class PyOpenCLMagics(Magics): def _run_kernel(self, kernel, options): - kernel = _try_to_utf8(kernel) - options = _try_to_utf8(options).strip() + if sys.version_info < (3,): + kernel = _try_to_utf8(kernel) + options = _try_to_utf8(options).strip() + try: ctx = self.shell.user_ns["cl_ctx"] except KeyError: @@ -34,37 +37,33 @@ class PyOpenCLMagics(Magics): raise RuntimeError("unable to locate cl context, which must be " "present in namespace as 'cl_ctx' or 'ctx'") - prg = cl.Program(ctx, kernel).build(options=options) + prg = cl.Program(ctx, kernel).build(options=options.split()) for knl in prg.all_kernels(): self.shell.user_ns[knl.function_name] = knl - @cell_magic def cl_kernel(self, line, cell): kernel = cell - opts, args = self.parse_options(line,'o:') + opts, args = self.parse_options(line, 'o:') build_options = opts.get('o', '') self._run_kernel(kernel, build_options) - def _load_kernel_and_options(self, line): - opts, args = self.parse_options(line,'o:f:') + opts, args = self.parse_options(line, 'o:f:') build_options = opts.get('o') kernel = self.shell.find_user_code(opts.get('f') or args) return kernel, build_options - @line_magic def cl_kernel_from_file(self, line): kernel, build_options = self._load_kernel_and_options(line) self._run_kernel(kernel, build_options) - @line_magic def cl_load_edit_kernel(self, line): kernel, build_options = self._load_kernel_and_options(line) diff --git a/test/test_array.py b/test/test_array.py index eb29b9b6..65b22362 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -719,6 +719,24 @@ def test_view_and_strides(ctx_factory): assert (y.get() == X.get()[:3, :5]).all() +def test_meshmode_view(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + n = 2 + result = cl.array.empty(queue, (2, n*6), np.float64) + + def view(z): + return z[..., n*3:n*6].reshape(z.shape[:-1] + (n, 3)) + + result = result.with_queue(queue) + result.fill(0) + view(result)[0].fill(1) + view(result)[1].fill(1) + x = result.get() + assert (view(x) == 1).all() + + def test_event_management(ctx_factory): context = ctx_factory() queue = cl.CommandQueue(context) -- GitLab