From cb381a08e19bdc4bfedefee5fc467fa1ebb719b0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sat, 8 Nov 2014 11:41:40 -0500 Subject: [PATCH] Implement hstack() for arrays --- pyopencl/array.py | 29 +++++++++++++++++++++++++++++ pyopencl/tools.py | 4 ++++ 2 files changed, 33 insertions(+) diff --git a/pyopencl/array.py b/pyopencl/array.py index b43756f9..d00253a4 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -57,6 +57,7 @@ except: def _dtype_is_object(t): return False + # {{{ vector types class vec: @@ -2018,6 +2019,34 @@ def diff(array, queue=None, allocator=None): _diff(result, array, queue=queue) return result + +def hstack(arrays, queue=None): + from pyopencl.array import empty + + if len(arrays) == 0: + return empty(queue, (), dtype=np.float64) + + if queue is None: + for ary in arrays: + if ary.queue is not None: + queue = ary.queue + break + + from pytools import all_equal, single_valued + if not all_equal(len(ary.shape) for ary in arrays): + raise ValueError("arguments must all have the same number of axes") + + lead_shape = single_valued(ary.shape[:-1] for ary in arrays) + + w = _builtin_sum([ary.shape[-1] for ary in arrays]) + result = empty(queue, lead_shape+(w,), arrays[0].dtype) + index = 0 + for ary in arrays: + result[..., index:index+ary.shape[-1]] = ary + index += ary.shape[-1] + + return result + # }}} diff --git a/pyopencl/tools.py b/pyopencl/tools.py index 92bf3a80..32aa8e64 100644 --- a/pyopencl/tools.py +++ b/pyopencl/tools.py @@ -933,6 +933,10 @@ class _CLFakeArrayModule: from pyopencl.array import empty return empty(self.queue, shape, dtype, order=order) + def hstack(self, arrays): + from pyopencl.array import hstack + return hstack(arrays, self.queue) + def array_module(a): if isinstance(a, np.ndarray): -- GitLab