diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py
index 9e6c762ad4b93c2026dd20a3e9fa0776a3aec1a3..df364eda3c883d378c1e9d25136d8f59f5763f9d 100644
--- a/pyopencl/elementwise.py
+++ b/pyopencl/elementwise.py
@@ -508,28 +508,27 @@ def get_axpbyz_kernel(context, dtype_x, dtype_y, dtype_z):
     x_is_complex = dtype_x.kind == "c"
     y_is_complex = dtype_y.kind == "c"
 
-    if x_is_complex:
-        ax = "%s_mul(a, x[i])" % complex_dtype_to_name(dtype_x)
-    elif not x_is_complex and y_is_complex:
-        ax = "{}_fromreal({})".format(complex_dtype_to_name(dtype_y), ax)
-    else:
-        ax = f"a*(({result_t}) x[i])"
+    if dtype_z.kind == "c":
+        # a and b will always be complex here.
+        z_ct = complex_dtype_to_name(dtype_z)
 
-    if y_is_complex:
-        by = "%s_mul(b, y[i])" % complex_dtype_to_name(dtype_y)
-    elif x_is_complex and not y_is_complex:
-        by = "{}_fromreal({})".format(complex_dtype_to_name(dtype_x), by)
+        if x_is_complex:
+            ax = f"{z_ct}_mul(a, {z_ct}_cast(x[i]))"
+        else:
+            ax = f"{z_ct}_mulr(a, x[i])"
+
+        if y_is_complex:
+            by = f"{z_ct}_mul(b, {z_ct}_cast(y[i]))"
+        else:
+            by = f"{z_ct}_mulr(b, y[i])"
+
+        result = f"{z_ct}_add({ax}, {by})"
     else:
+        # real-only
+
+        ax = f"a*(({result_t}) x[i])"
         by = f"b*(({result_t}) y[i])"
 
-    if x_is_complex or y_is_complex:
-        result = (
-                "{root}_add({root}_cast({ax}), {root}_cast({by}))"
-                .format(
-                    ax=ax,
-                    by=by,
-                    root=complex_dtype_to_name(dtype_z)))
-    else:
         result = f"{ax} + {by}"
 
     return get_elwise_kernel(context,