From 85531860e803c560dccaa48ce5b6f00cfe99001b Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 27 Nov 2014 23:41:04 -0600
Subject: [PATCH] Fix non-contiguous reshape

---
 pyopencl/array.py       | 98 +++++++++++++++++++++++++++++++++++++++--
 pyopencl/ipython_ext.py | 17 ++++---
 test/test_array.py      | 18 ++++++++
 3 files changed, 121 insertions(+), 12 deletions(-)

diff --git a/pyopencl/array.py b/pyopencl/array.py
index d00253a4..71ce3d8c 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -1288,20 +1288,112 @@ class Array(object):
             raise TypeError("unexpected keyword arguments: %s"
                     % kwargs.keys())
 
+        if order not in ("C", "F"):
+            raise ValueError("order must be either 'C' or 'F'")
+
         # TODO: add more error-checking, perhaps
+
         if isinstance(shape[0], tuple) or isinstance(shape[0], list):
             shape = tuple(shape[0])
 
+        if any(s < 0 for s in shape):
+            raise NotImplementedError("negative/automatic shapes not supported")
+
         if shape == self.shape:
-            return self
+            return self._new_with_changes(
+                    data=self.base_data, offset=self.offset, shape=shape,
+                    strides=self.strides)
 
-        size = reduce(lambda x, y: x * y, shape, 1)
+        import operator
+        size = reduce(operator.mul, shape, 1)
         if size != self.size:
             raise ValueError("total size of new array must be unchanged")
 
+        # {{{ determine reshaped strides
+
+        # copied and translated from
+        # https://github.com/numpy/numpy/blob/4083883228d61a3b571dec640185b5a5d983bf59/numpy/core/src/multiarray/shape.c  # noqa
+
+        newdims = shape
+        newnd = len(newdims)
+
+        # Remove axes with dimension 1 from the old array. They have no effect
+        # but would need special cases since their strides do not matter.
+
+        olddims = []
+        oldstrides = []
+        for oi in range(len(self.shape)):
+            s = self.shape[oi]
+            if s != 1:
+                olddims.append(s)
+                oldstrides.append(self.strides[oi])
+
+        oldnd = len(olddims)
+
+        newstrides = [-1]*len(newdims)
+
+        # oi to oj and ni to nj give the axis ranges currently worked with
+        oi = 0
+        oj = 1
+        ni = 0
+        nj = 1
+        while ni < newnd and oi < oldnd:
+            np = newdims[ni]
+            op = olddims[oi]
+
+            while np != op:
+                if np < op:
+                    # Misses trailing 1s, these are handled later
+                    np *= newdims[nj]
+                    nj += 1
+                else:
+                    op *= olddims[oj]
+                    oj += 1
+
+            # Check whether the original axes can be combined
+            for ok in range(oi, oj-1):
+                if order == "F":
+                    if oldstrides[ok+1] != olddims[ok]*oldstrides[ok]:
+                        raise ValueError("cannot reshape without copy")
+                else:
+                    # C order
+                    if (oldstrides[ok] != olddims[ok+1]*oldstrides[ok+1]):
+                        raise ValueError("cannot reshape without copy")
+
+            # Calculate new strides for all axes currently worked with
+            if order == "F":
+                newstrides[ni] = oldstrides[oi]
+                for nk in range(ni+1, nj):
+                    newstrides[nk] = newstrides[nk - 1]*newdims[nk - 1]
+            else:
+                # C order
+                newstrides[nj - 1] = oldstrides[oj - 1]
+                for nk in range(nj-1, ni, -1):
+                    newstrides[nk - 1] = newstrides[nk]*newdims[nk]
+
+            ni = nj
+            nj += 1
+
+            oi = oj
+            oj += 1
+
+        # Set strides corresponding to trailing 1s of the new shape.
+        if ni >= 1:
+            last_stride = newstrides[ni - 1]
+        else:
+            last_stride = self.dtype.itemsize
+
+        if order == "F":
+            last_stride *= newdims[ni - 1]
+
+        for nk in range(ni, len(shape)):
+            newstrides[nk] = last_stride
+
+        # }}}
+
         return self._new_with_changes(
                 data=self.base_data, offset=self.offset, shape=shape,
-                strides=_make_strides(self.dtype.itemsize, shape, order))
+                strides=tuple(newstrides))
 
     def ravel(self):
         """Returns flattened array containing the same data."""
diff --git a/pyopencl/ipython_ext.py b/pyopencl/ipython_ext.py
index 81fbcdf8..bedbf77f 100644
--- a/pyopencl/ipython_ext.py
+++ b/pyopencl/ipython_ext.py
@@ -3,6 +3,7 @@ from __future__ import division
 from IPython.core.magic import (magics_class, Magics, cell_magic, line_magic)
 
 import pyopencl as cl
+import sys
 
 
 def _try_to_utf8(text):
@@ -14,8 +15,10 @@ def _try_to_utf8(text):
 @magics_class
 class PyOpenCLMagics(Magics):
     def _run_kernel(self, kernel, options):
-        kernel = _try_to_utf8(kernel)
-        options = _try_to_utf8(options).strip()
+        options = options or ""  # None when -o was not given
+        if sys.version_info < (3,):
+            kernel = _try_to_utf8(kernel)
+            options = _try_to_utf8(options).strip()
         try:
             ctx = self.shell.user_ns["cl_ctx"]
         except KeyError:
@@ -34,37 +37,33 @@ class PyOpenCLMagics(Magics):
             raise RuntimeError("unable to locate cl context, which must be "
                     "present in namespace as 'cl_ctx' or 'ctx'")
 
-        prg = cl.Program(ctx, kernel).build(options=options)
+        prg = cl.Program(ctx, kernel).build(options=options.split())
 
         for knl in prg.all_kernels():
             self.shell.user_ns[knl.function_name] = knl
 
-
     @cell_magic
     def cl_kernel(self, line, cell):
         kernel = cell
 
-        opts, args = self.parse_options(line,'o:')
+        opts, args = self.parse_options(line, 'o:')
         build_options = opts.get('o', '')
 
         self._run_kernel(kernel, build_options)
 
-
     def _load_kernel_and_options(self, line):
-        opts, args = self.parse_options(line,'o:f:')
+        opts, args = self.parse_options(line, 'o:f:')
 
         build_options = opts.get('o')
         kernel = self.shell.find_user_code(opts.get('f') or args)
 
         return kernel, build_options
 
-
     @line_magic
     def cl_kernel_from_file(self, line):
         kernel, build_options = self._load_kernel_and_options(line)
         self._run_kernel(kernel, build_options)
 
-
     @line_magic
     def cl_load_edit_kernel(self, line):
         kernel, build_options = self._load_kernel_and_options(line)
diff --git a/test/test_array.py b/test/test_array.py
index eb29b9b6..65b22362 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -719,6 +719,24 @@ def test_view_and_strides(ctx_factory):
         assert (y.get() == X.get()[:3, :5]).all()
 
 
+def test_meshmode_view(ctx_factory):
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    n = 2
+    result = cl.array.empty(queue, (2, n*6), np.float64)
+
+    def view(z):
+        return z[..., n*3:n*6].reshape(z.shape[:-1] + (n, 3))
+
+    result = result.with_queue(queue)
+    result.fill(0)
+    view(result)[0].fill(1)
+    view(result)[1].fill(1)
+    x = result.get()
+    assert (view(x) == 1).all()
+
+
 def test_event_management(ctx_factory):
     context = ctx_factory()
     queue = cl.CommandQueue(context)
-- 
GitLab