diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py index 9e6c762ad4b93c2026dd20a3e9fa0776a3aec1a3..df364eda3c883d378c1e9d25136d8f59f5763f9d 100644 --- a/pyopencl/elementwise.py +++ b/pyopencl/elementwise.py @@ -508,28 +508,27 @@ def get_axpbyz_kernel(context, dtype_x, dtype_y, dtype_z): x_is_complex = dtype_x.kind == "c" y_is_complex = dtype_y.kind == "c" - if x_is_complex: - ax = "%s_mul(a, x[i])" % complex_dtype_to_name(dtype_x) - elif not x_is_complex and y_is_complex: - ax = "{}_fromreal({})".format(complex_dtype_to_name(dtype_y), ax) - else: - ax = f"a*(({result_t}) x[i])" + if dtype_z.kind == "c": + # a and b will always be complex here. + z_ct = complex_dtype_to_name(dtype_z) - if y_is_complex: - by = "%s_mul(b, y[i])" % complex_dtype_to_name(dtype_y) - elif x_is_complex and not y_is_complex: - by = "{}_fromreal({})".format(complex_dtype_to_name(dtype_x), by) + if x_is_complex: + ax = f"{z_ct}_mul(a, {z_ct}_cast(x[i]))" + else: + ax = f"{z_ct}_mulr(a, x[i])" + + if y_is_complex: + by = f"{z_ct}_mul(b, {z_ct}_cast(y[i]))" + else: + by = f"{z_ct}_mulr(b, y[i])" + + result = f"{z_ct}_add({ax}, {by})" else: + # real-only + + ax = f"a*(({result_t}) x[i])" by = f"b*(({result_t}) y[i])" - if x_is_complex or y_is_complex: - result = ( - "{root}_add({root}_cast({ax}), {root}_cast({by}))" - .format( - ax=ax, - by=by, - root=complex_dtype_to_name(dtype_z))) - else: result = f"{ax} + {by}" return get_elwise_kernel(context,