From f2d9d6e6c9676932dfc3322b8234c19e0fac8b90 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 15 Apr 2011 20:52:54 -0400
Subject: [PATCH] Store stride information in arrays.

---
 doc/source/array.rst |  15 ++++-
 doc/source/misc.rst  |   2 +
 pyopencl/array.py    | 148 ++++++++++++++++++++++++++++++++-----------
 test/test_array.py   |  14 ++++
 4 files changed, 141 insertions(+), 38 deletions(-)

diff --git a/doc/source/array.rst b/doc/source/array.rst
index f9b2cedb..56ab1995 100644
--- a/doc/source/array.rst
+++ b/doc/source/array.rst
@@ -23,7 +23,7 @@ The :class:`Array` Class
 
     An alias for :class:`pyopencl.tools.CLAllocator`.
 
-.. class:: Array(cqa, shape, dtype, order="C", allocator=None, base=None, data=None, queue=None)
+.. class:: Array(cqa, shape, dtype, order="C", *, allocator=None, base=None, data=None, queue=None)
 
     A :class:`numpy.ndarray` work-alike that stores its data and performs its
     computations on the compute device.  *shape* and *dtype* work exactly as in
@@ -42,9 +42,12 @@ The :class:`Array` Class
     :class:`int` representing the address of the newly allocated memory.
     (See :class:`DefaultAllocator`.)
 
+
     .. versionchanged:: 2011.1
         Renamed *context* to *cqa*, made it general-purpose.
 
+        All arguments beyond *order* should be considered keyword-only.
+
     .. attribute :: data
 
         The :class:`pyopencl.MemoryObject` instance created for the memory that backs
@@ -73,6 +76,16 @@ The :class:`Array` Class
         The size of the entire array in bytes. Computed as :attr:`size` times
         ``dtype.itemsize``.
 
+    .. attribute :: strides
+
+        Tuple of bytes to step in each dimension when traversing an array.
+
+    .. attribute :: flags
+
+        Return an object with attributes `c_contiguous`, `f_contiguous` and `forc`,
+        which may be used to query contiguity properties in analogy to
+        :attr:`numpy.ndarray.flags`.
+
     .. method :: __len__()
 
         Returns the size of the leading dimension of *self*.
diff --git a/doc/source/misc.rst b/doc/source/misc.rst
index 7ae20c5e..1a2b0313 100644
--- a/doc/source/misc.rst
+++ b/doc/source/misc.rst
@@ -95,6 +95,8 @@ Version 2011.1
 * Make construction of :class:`pyopencl.array.Array` more flexible (*cqa* argument.)
 * Add :ref:`memory-pools`.
 * Add vector types, see :class:`pyopencl.array.vec`.
+* Add :attr:`pyopencl.array.Array.strides`, :attr:`pyopencl.array.Array.flags`.
+  Allow the creation of arrys in C and Fortran order.
 
 Version 0.92
 ------------
diff --git a/pyopencl/array.py b/pyopencl/array.py
index 71a2206f..f78cab4b 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -184,6 +184,48 @@ def _should_be_cqa(what):
             "versions 2011.x of PyOpenCL." % (what, what),
             DeprecationWarning, 3)
 
+
+
+
+def _f_contiguous_strides(itemsize, shape):
+    if shape:
+        strides = [itemsize]
+        for s in shape[:-1]:
+            strides.append(strides[-1]*s)
+        return tuple(strides)
+    else:
+        return ()
+
+def _c_contiguous_strides(itemsize, shape):
+    if shape:
+        strides = [itemsize]
+        for s in shape[:0:-1]:
+            strides.append(strides[-1]*s)
+        return tuple(strides[::-1])
+    else:
+        return ()
+
+
+
+
+class _ArrayFlags:
+    def __init__(self, ary):
+        self.array = ary
+
+    @property
+    def f_contiguous(self):
+        return self.array.strides == _f_contiguous_strides(
+                self.array.dtype.itemsize, self.array.shape)
+
+    @property
+    def c_contiguous(self):
+        return self.array.strides == _c_contiguous_strides(
+                self.array.dtype.itemsize, self.array.shape)
+
+    @property
+    def forc(self):
+        return self.f_contiguous or self.c_contiguous
+
 # }}}
 
 # {{{ array class
@@ -197,7 +239,7 @@ class Array(object):
     """
 
     def __init__(self, cqa, shape, dtype, order="C", allocator=None,
-            base=None, data=None, queue=None):
+            base=None, data=None, queue=None, strides=None):
         # {{{ backward compatibility for pre-cqa days
 
         if isinstance(cqa, cl.CommandQueue):
@@ -231,6 +273,9 @@ class Array(object):
 
         # invariant here: allocator, queue set
 
+        # {{{ determine shape and strides
+        dtype = np.dtype(dtype)
+
         try:
             s = 1
             for dim in shape:
@@ -242,13 +287,27 @@ class Array(object):
             s = shape
             shape = (shape,)
 
-        self.queue = queue
+        if strides is None:
+            if order == "F":
+                strides = _f_contiguous_strides(
+                        dtype.itemsize, shape)
+            elif order == "C":
+                strides = _c_contiguous_strides(
+                        dtype.itemsize, shape)
+            else:
+                raise ValueError("invalid order: %s" % order)
+        else:
+            # FIXME: We should possibly perform some plausibility
+            # checking on 'strides' here.
 
+            strides = tuple(strides)
+
+        # }}}
+
+        self.queue = queue
         self.shape = shape
-        self.dtype = np.dtype(dtype)
-        if order not in ["C", "F"]:
-            raise ValueError("order must be either 'C' or 'F'")
-        self.order = order
+        self.dtype = dtype
+        self.strides = strides
 
         self.mem_size = self.size = s
         self.nbytes = self.dtype.itemsize * self.size
@@ -269,6 +328,10 @@ class Array(object):
 
         self.context = self.data.context
 
+    @property
+    def flags(self):
+        return _ArrayFlags(self)
+
     #@memoize_method FIXME: reenable
     def get_sizes(self, queue):
         return splay(queue, self.mem_size)
@@ -276,26 +339,36 @@ class Array(object):
     def set(self, ary, queue=None, async=False):
         assert ary.size == self.size
         assert ary.dtype == self.dtype
-        if self.size:
-            evt = cl.enqueue_write_buffer(queue or self.queue, self.data, ary)
+        assert self.flags.forc
+
+        if not ary.flags.forc:
+            if async:
+                raise RuntimeError("cannot asynchronously set from "
+                        "non-contiguous array")
+
+            ary = ary.copy()
 
-            if not async:
-                evt.wait()
+        if self.size:
+            cl.enqueue_write_buffer(queue or self.queue, self.data, ary, 
+                    is_blocking=not async)
 
     def get(self, queue=None, ary=None, async=False):
         if ary is None:
-            ary = np.empty(self.shape, self.dtype, order=self.order)
+            ary = np.empty(self.shape, self.dtype)
+
+            from numpy.lib.stride_tricks import as_strided
+            ary = as_strided(ary, strides=self.strides)
         else:
             if ary.size != self.size:
                 raise TypeError("'ary' has non-matching type")
             if ary.dtype != self.dtype:
                 raise TypeError("'ary' has non-matching size")
 
-        if self.size:
-            evt = cl.enqueue_read_buffer(queue or self.queue, self.data, ary)
+        assert self.flags.forc, "Array in get() must be contiguous"
 
-            if not async:
-                evt.wait()
+        if self.size:
+            cl.enqueue_read_buffer(queue or self.queue, self.data, ary,
+                    is_blocking=not async)
 
         return ary
 
@@ -388,15 +461,23 @@ class Array(object):
                 dest.context, dest.dtype, src.dtype)
 
     def _new_like_me(self, dtype=None, queue=None):
+        strides = None
         if dtype is None:
             dtype = self.dtype
+        else:
+            if dtype == self.dtype:
+                strides = self.strides
+
         queue = queue or self.queue
         if queue is not None:
-            return self.__class__(queue, self.shape, dtype, allocator=self.allocator)
+            return self.__class__(queue, self.shape, dtype, 
+                    allocator=self.allocator, strides=strides)
         elif self.allocator is not None:
-            return self.__class__(self.allocator, self.shape, dtype)
+            return self.__class__(self.allocator, self.shape, dtype,
+                    strides=strides)
         else:
-            return self.__class__(self.context, self.shape, dtype)
+            return self.__class__(self.context, self.shape, dtype,
+                    strides=strides)
 
     # operators ---------------------------------------------------------------
     def mul_add(self, selffac, other, otherfac, queue=None):
@@ -609,22 +690,15 @@ class Array(object):
 
 # {{{ creation helpers
 
-def _to_device(queue, ary, allocator=None, async=False):
-    if ary.flags.f_contiguous:
-        order = "F"
-    elif ary.flags.c_contiguous:
-        order = "C"
-    else:
-        raise ValueError("to_device only works on C- or Fortran-"
-                "contiguous arrays")
-
-    result = Array(queue, ary.shape, ary.dtype, order, allocator)
-    result.set(ary, async=async)
-    return result
-
 def to_device(*args, **kwargs):
     """Converts a numpy array to a :class:`Array`."""
 
+    def _to_device(queue, ary, allocator=None, async=False):
+        result = Array(queue, ary.shape, ary.dtype, allocator, strides=ary.strides)
+        result.set(ary, async=async)
+        return result
+
+
     if isinstance(args[0], cl.Context):
         from warnings import warn
         warn("Passing a context as first argument is deprecated. "
@@ -640,15 +714,15 @@ def to_device(*args, **kwargs):
 
 empty = Array
 
-def _zeros(queue, shape, dtype, order="C", allocator=None):
-    result = Array(queue, shape, dtype, 
-            order=order, allocator=allocator)
-    result.fill(0)
-    return result
-
 def zeros(*args, **kwargs):
     """Returns an array of the given shape and dtype filled with 0's."""
 
+    def _zeros(queue, shape, dtype, order="C", allocator=None):
+        result = Array(queue, shape, dtype,
+                order=order, allocator=allocator)
+        result.fill(0)
+        return result
+
     if isinstance(args[0], cl.Context):
         from warnings import warn
         warn("Passing a context as first argument is deprecated. "
diff --git a/test/test_array.py b/test/test_array.py
index 636a53b9..2e71bbda 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -589,6 +589,20 @@ def test_scan(ctx_getter):
 
 
 
+@pytools.test.mark_test.opencl
+def test_stride_preservation(ctx_getter):
+    context = ctx_getter()
+    queue = cl.CommandQueue(context)
+
+    A = np.random.rand(3,3)
+    AT = A.T
+    print AT.flags.f_contiguous, AT.flags.c_contiguous
+    AT_GPU = cl_array.to_device(queue, AT)
+    print AT_GPU.flags.f_contiguous, AT_GPU.flags.c_contiguous
+    assert np.allclose(AT_GPU.get(),AT)
+
+
+
 
 if __name__ == "__main__":
     # make sure that import failures get reported, instead of skipping the tests.
-- 
GitLab