diff --git a/pyopencl/array.py b/pyopencl/array.py index cdc83fd96d9edb77c6d854884279910cee0bb0c6..2e846e29de1d2c2b27663257ae115e1be80f735d 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -2987,6 +2987,9 @@ def sum(a, dtype=None, queue=None, slice=None, initial=np._NoValue): if initial is not np._NoValue and not isinstance(initial, SCALAR_CLASSES): raise ValueError("'initial' is not a scalar") + if dtype is not None: + dtype = np.dtype(dtype) + from pyopencl.reduction import get_sum_kernel krnl = get_sum_kernel(a.context, dtype, a.dtype) result, event1 = krnl(a, queue=queue, slice=slice, wait_for=a.events, @@ -3018,11 +3021,16 @@ def dot(a, b, dtype=None, queue=None, slice=None): """ .. versionadded:: 2011.1 """ + if dtype is not None: + dtype = np.dtype(dtype) + from pyopencl.reduction import get_dot_kernel krnl = get_dot_kernel(a.context, dtype, a.dtype, b.dtype) + result, event1 = krnl(a, b, queue=queue, slice=slice, wait_for=a.events + b.events, return_event=True) result.add_event(event1) + return result @@ -3031,12 +3039,17 @@ def vdot(a, b, dtype=None, queue=None, slice=None): .. versionadded:: 2013.1 """ + if dtype is not None: + dtype = np.dtype(dtype) + from pyopencl.reduction import get_dot_kernel krnl = get_dot_kernel(a.context, dtype, a.dtype, b.dtype, conjugate_first=True) + result, event1 = krnl(a, b, queue=queue, slice=slice, wait_for=a.events + b.events, return_event=True) result.add_event(event1) + return result @@ -3044,12 +3057,17 @@ def subset_dot(subset, a, b, dtype=None, queue=None, slice=None): """ .. versionadded:: 2011.1 """ + if dtype is not None: + dtype = np.dtype(dtype) + from pyopencl.reduction import get_subset_dot_kernel krnl = get_subset_dot_kernel( a.context, dtype, subset.dtype, a.dtype, b.dtype) + result, event1 = krnl(subset, a, b, queue=queue, slice=slice, wait_for=subset.events + a.events + b.events, return_event=True) result.add_event(event1) + return result @@ -3130,6 +3148,9 @@ def cumsum(a, output_dtype=None, queue=None, if output_dtype is None: output_dtype = a.dtype + else: + output_dtype = np.dtype(output_dtype) + if wait_for is None: wait_for = []