Skip to content
Snippets Groups Projects
Commit 477c420b authored by Alexandru Fikl's avatar Alexandru Fikl Committed by Andreas Klöckner
Browse files

ensure dtype is np.dtype in reductions

parent 4f41f14b
No related branches found
No related tags found
No related merge requests found
......@@ -2987,6 +2987,9 @@ def sum(a, dtype=None, queue=None, slice=None, initial=np._NoValue):
if initial is not np._NoValue and not isinstance(initial, SCALAR_CLASSES):
raise ValueError("'initial' is not a scalar")
if dtype is not None:
dtype = np.dtype(dtype)
from pyopencl.reduction import get_sum_kernel
krnl = get_sum_kernel(a.context, dtype, a.dtype)
result, event1 = krnl(a, queue=queue, slice=slice, wait_for=a.events,
......@@ -3018,11 +3021,16 @@ def dot(a, b, dtype=None, queue=None, slice=None):
"""
.. versionadded:: 2011.1
"""
if dtype is not None:
dtype = np.dtype(dtype)
from pyopencl.reduction import get_dot_kernel
krnl = get_dot_kernel(a.context, dtype, a.dtype, b.dtype)
result, event1 = krnl(a, b, queue=queue, slice=slice,
wait_for=a.events + b.events, return_event=True)
result.add_event(event1)
return result
......@@ -3031,12 +3039,17 @@ def vdot(a, b, dtype=None, queue=None, slice=None):
.. versionadded:: 2013.1
"""
if dtype is not None:
dtype = np.dtype(dtype)
from pyopencl.reduction import get_dot_kernel
krnl = get_dot_kernel(a.context, dtype, a.dtype, b.dtype,
conjugate_first=True)
result, event1 = krnl(a, b, queue=queue, slice=slice,
wait_for=a.events + b.events, return_event=True)
result.add_event(event1)
return result
......@@ -3044,12 +3057,17 @@ def subset_dot(subset, a, b, dtype=None, queue=None, slice=None):
"""
.. versionadded:: 2011.1
"""
if dtype is not None:
dtype = np.dtype(dtype)
from pyopencl.reduction import get_subset_dot_kernel
krnl = get_subset_dot_kernel(
a.context, dtype, subset.dtype, a.dtype, b.dtype)
result, event1 = krnl(subset, a, b, queue=queue, slice=slice,
wait_for=subset.events + a.events + b.events, return_event=True)
result.add_event(event1)
return result
......@@ -3130,6 +3148,9 @@ def cumsum(a, output_dtype=None, queue=None,
if output_dtype is None:
output_dtype = a.dtype
else:
output_dtype = np.dtype(output_dtype)
if wait_for is None:
wait_for = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment