diff --git a/pyopencl/array.py b/pyopencl/array.py index 4e8b52a130a142daee7cfe7204736a022b4cb201..9bc0739ea5739d4572db590da629d325e636b74a 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -477,8 +477,8 @@ class Array: queue = None else: - raise TypeError("cq may be a queue or a context, not '%s'" - % type(cq)) + raise TypeError( + f"cq may be a queue or a context, not '{type(cq).__name__}'") if allocator is not None: # "is" would be wrong because two Python objects are allowed @@ -2402,7 +2402,7 @@ def take(a, indices, out=None, queue=None, wait_for=None): queue = queue or a.queue if out is None: - out = Array(queue, indices.shape, a.dtype, allocator=a.allocator) + out = type(a)(queue, indices.shape, a.dtype, allocator=a.allocator) assert len(indices.shape) == 1 out.add_event( @@ -2425,9 +2425,11 @@ def multi_take(arrays, indices, out=None, queue=None): vec_count = len(arrays) if out is None: - out = [Array(context, queue, indices.shape, a_dtype, - allocator=a_allocator) - for i in range(vec_count)] + out = [ + type(arrays[i])( + context, queue, indices.shape, a_dtype, + allocator=a_allocator) + for i in range(vec_count)] else: if len(out) != len(arrays): raise ValueError("out and arrays must have the same length") @@ -2482,7 +2484,7 @@ def multi_take_put(arrays, dest_indices, src_indices, dest_shape=None, vec_count = len(arrays) if out is None: - out = [Array(queue, dest_shape, a_dtype, allocator=a_allocator) + out = [type(arrays[i])(queue, dest_shape, a_dtype, allocator=a_allocator) for i in range(vec_count)] else: if a_dtype != single_valued(o.dtype for o in out): @@ -2564,8 +2566,8 @@ def multi_put(arrays, dest_indices, dest_shape=None, out=None, queue=None, vec_count = len(arrays) if out is None: - out = [Array(queue, dest_shape, a_dtype, allocator=a_allocator) - for _ in range(vec_count)] + out = [type(arrays[i])(queue, dest_shape, a_dtype, allocator=a_allocator) + for i in range(vec_count)] else: if a_dtype != single_valued(o.dtype for o in out): raise TypeError("arrays and out must have the same dtype") @@ -2908,7 +2910,6 @@ def if_positive(criterion, then_, else_, out=None, queue=None): raise AssertionError() if out is None: - if then_.shape != (): out = empty_like( then_, criterion.queue, allocator=criterion.allocator) @@ -2918,7 +2919,8 @@ def if_positive(criterion, then_, else_, out=None, queue=None): cr_item_strides = cr_byte_strides // criterion.dtype.itemsize out_strides = tuple(cr_item_strides*then_.dtype.itemsize) - out = Array(criterion.queue, criterion.shape, then_.dtype, + out = type(criterion)( + criterion.queue, criterion.shape, then_.dtype, allocator=criterion.allocator, strides=out_strides)