From 561ecd8ccad14bfc068fa8d0cde0b0890d069c27 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Fri, 14 Jan 2022 19:51:10 -0600 Subject: [PATCH] prefer input array's constructor instead of cla.Array --- pyopencl/array.py | 53 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/pyopencl/array.py b/pyopencl/array.py index 49a1dbe9..8101bea6 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -2621,6 +2621,10 @@ def multi_put(arrays, dest_indices, dest_shape=None, out=None, queue=None, def concatenate(arrays, axis=0, queue=None, allocator=None): """ .. versionadded:: 2013.1 + + .. note:: + + The returned array is of the same type as the first array in the list. """ if not arrays: raise ValueError("need at least one array to concatenate") @@ -2654,7 +2658,17 @@ def concatenate(arrays, axis=0, queue=None, allocator=None): shape = tuple(shape) dtype = np.find_common_type([ary.dtype for ary in arrays], []) - result = empty(queue, shape, dtype, allocator=allocator) + + if __debug__: + import builtins + if builtins.any(type(ary) != type(arrays[0]) # noqa: E721 + for ary in arrays[1:]): + from warnings import warn + warn("Elements of 'arrays' not of the same type, returning " + "an instance of the type of arrays[0]", + stacklevel=2) + + result = arrays[0].__class__(queue, shape, dtype, allocator=allocator) full_slice = (slice(None),) * len(shape) @@ -2690,15 +2704,13 @@ def diff(array, queue=None, allocator=None): queue = queue or array.queue allocator = allocator or array.allocator - result = empty(queue, (n-1,), array.dtype, allocator=allocator) + result = array.__class__(queue, (n-1,), array.dtype, allocator=allocator) event1 = _diff(result, array, queue=queue) result.add_event(event1) return result def hstack(arrays, queue=None): - from pyopencl.array import empty - if len(arrays) == 0: raise ValueError("need at least one array to hstack") @@ -2715,7 +2727,17 @@ def hstack(arrays, queue=None): lead_shape = single_valued(ary.shape[:-1] for ary in arrays) w = _builtin_sum([ary.shape[-1] for ary in arrays]) - result = empty(queue, lead_shape+(w,), arrays[0].dtype) + + if __debug__: + import builtins + if builtins.any(type(ary) != type(arrays[0]) # noqa: E721 + for ary in arrays[1:]): + from warnings import warn + warn("Elements of 'arrays' not of the same type, returning " + "an instance of the type of arrays[0]", + stacklevel=2) + + result = arrays[0].__class__(queue, lead_shape+(w,), arrays[0].dtype) index = 0 for ary in arrays: result[..., index:index+ary.shape[-1]] = ary @@ -2764,11 +2786,22 @@ def stack(arrays, axis=0, queue=None): raise NotImplementedError result_shape = input_shape[:axis] + (len(arrays),) + input_shape[axis:] - result = empty(queue, result_shape, np.result_type(*(ary.dtype - for ary in arrays)), - # TODO: reconsider once arrays support non-contiguous - # assignments - order="C" if axis == 0 else "F") + + if __debug__: + import builtins + if builtins.any(type(ary) != type(arrays[0]) # noqa: E721 + for ary in arrays[1:]): + from warnings import warn + warn("Elements of 'arrays' not of the same type, returning " + "an instance of the type of arrays[0]", + stacklevel=2) + + result = arrays[0].__class__(queue, result_shape, + np.result_type(*(ary.dtype + for ary in arrays)), + # TODO: reconsider once arrays support + # non-contiguous assignments + order="C" if axis == 0 else "F") for i, ary in enumerate(arrays): idx = (slice(None),)*axis + (i,) + (slice(None),)*(input_ndim-axis) result[idx] = ary -- GitLab