Skip to content
Snippets Groups Projects
Commit d3d89f62 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Make dot work for all real/complex type combinations.

parent 70ec9e1d
No related branches found
No related tags found
No related merge requests found
......@@ -393,35 +393,54 @@ def get_sum_kernel(ctx, dtype_out, dtype_in):
@context_dependent_memoize
def get_dot_kernel(ctx, dtype_out, dtype_a=None, dtype_b=None):
if dtype_out is None:
dtype_out = dtype_a
if dtype_b is None:
if dtype_a is None:
dtype_b = dtype_out
else:
dtype_b = dtype_a
if dtype_a is None:
dtype_a = dtype_out
if dtype_out is None:
from pyopencl.compyte.array import get_common_dtype
from pyopencl.characterize import has_double_support
dtype_out = get_common_dtype(
dtype_a.type(0), dtype_b.type(0), has_double_support(ctx.devices[0]))
a_real_dtype = dtype_a.type(0).real.dtype
b_real_dtype = dtype_b.type(0).real.dtype
out_real_dtype = dtype_out.type(0).real.dtype
a_is_complex = dtype_a.kind == "c"
b_is_complex = dtype_b.kind == "c"
out_is_complex = dtype_out.kind == "c"
if out_is_complex:
from pyopencl.elementwise import complex_dtype_to_name
if a_is_complex and b_is_complex:
a = "a[i]"
b = "b[i]"
from pyopencl.elementwise import complex_dtype_to_name
if a_is_complex and dtype_a != dtype_out:
if dtype_a != dtype_out:
a = "%s_cast(%s)" % (complex_dtype_to_name(dtype_out), a)
if b_is_complex and dtype_b != dtype_out:
if dtype_b != dtype_out:
b = "%s_cast(%s)" % (complex_dtype_to_name(dtype_out), b)
map_expr = "%s_mul(%s, %s)" % (
complex_dtype_to_name(dtype_out), a, b)
else:
map_expr = "a[i]*b[i]"
a = "a[i]"
b = "b[i]"
if out_is_complex:
if a_is_complex and dtype_a != dtype_out:
a = "%s_cast(%s)" % (complex_dtype_to_name(dtype_out), a)
if b_is_complex and dtype_b != dtype_out:
b = "%s_cast(%s)" % (complex_dtype_to_name(dtype_out), b)
if not a_is_complex and a_real_dtype != out_real_dtype:
a = "(%s) (%s)" % (dtype_to_ctype(out_real_dtype), a)
if not b_is_complex and b_real_dtype != out_real_dtype:
b = "(%s) (%s)" % (dtype_to_ctype(out_real_dtype), b)
map_expr = "%s*%s" % (a, b)
return ReductionKernel(ctx, dtype_out, neutral="0",
reduce_expr="a+b", map_expr=map_expr,
......
......@@ -607,17 +607,23 @@ def test_dot(ctx_factory):
context = ctx_factory()
queue = cl.CommandQueue(context)
for dtype in [np.float32, np.complex64]:
a_gpu = general_clrand(queue, (200000,), dtype)
a = a_gpu.get()
b_gpu = general_clrand(queue, (200000,), dtype)
b = b_gpu.get()
dtypes = [np.float32, np.complex64]
if has_double_support(context.devices[0]):
dtypes.extend([np.float64, np.complex128])
for a_dtype in dtypes:
for b_dtype in dtypes:
print a_dtype, b_dtype
a_gpu = general_clrand(queue, (200000,), a_dtype)
a = a_gpu.get()
b_gpu = general_clrand(queue, (200000,), b_dtype)
b = b_gpu.get()
dot_ab = np.dot(a, b)
dot_ab = np.dot(a, b)
dot_ab_gpu = cl_array.dot(a_gpu, b_gpu).get()
dot_ab_gpu = cl_array.dot(a_gpu, b_gpu).get()
assert abs(dot_ab_gpu - dot_ab) / abs(dot_ab) < 1e-4
assert abs(dot_ab_gpu - dot_ab) / abs(dot_ab) < 1e-4
if False:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment