diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 6c22572128717394beb90a2c19752b52b58acf72..7562f485cfb2809ad9dc1fc4cdbec45c269c97d8 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -1050,10 +1050,12 @@ class ExclusiveScanKernel(_ScanKernelBase):
 # {{{ higher-level trickery
 
 @context_dependent_memoize
-def get_copy_if_kernel(ctx, dtype, predicate):
+def get_copy_if_kernel(ctx, dtype, predicate, scan_dtype):
+    # *dtype* is the element type of 'ary'/'out'; *scan_dtype* is the integer
+    # type in which the 0/1 predicate flags are summed by the scan.
     ctype = dtype_to_ctype(dtype)
     return GenericScanKernel(
-            ctx, np.uint32,
+            ctx, scan_dtype,
             arguments="__global %s *ary, __global %s *out, __global unsigned long *count" % (ctype, ctype),
             input_expr="(%s) ? 1 : 0" % predicate,
             scan_expr="a+b", neutral="0",
@@ -1064,10 +1066,14 @@ def get_copy_if_kernel(ctx, dtype, predicate):
             )
 
 def copy_if(ary, predicate, queue=None):
-    # FIXME use 64-bit scan, eventually
-    # (not relevant for 6GB GPUs)
+    # The scan adds one per selected element, so its largest possible value
+    # is len(ary); only use a 64-bit scan when that would overflow 32 bits.
+    if len(ary) > np.iinfo(np.uint32).max:
+        scan_dtype = np.uint64
+    else:
+        scan_dtype = np.uint32
 
-    knl = get_copy_if_kernel(ary.context, ary.dtype, predicate)
+    knl = get_copy_if_kernel(ary.context, ary.dtype, predicate, scan_dtype)
     out = cl_array.empty_like(ary)
     count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64)
     knl(ary, out, count, queue=queue)