From 3c04b61bc187179ba2a8973fb31b6341b827483c Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 9 Sep 2020 17:23:13 -0500
Subject: [PATCH] Fix complex codegen logic in axpbyz given scalars are of
 result dtype

---
 pyopencl/elementwise.py | 35 +++++++++++++++++------------------
 1 file changed, 17 insertions(+), 18 deletions(-)

diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py
index 9e6c762a..df364eda 100644
--- a/pyopencl/elementwise.py
+++ b/pyopencl/elementwise.py
@@ -508,28 +508,27 @@ def get_axpbyz_kernel(context, dtype_x, dtype_y, dtype_z):
     x_is_complex = dtype_x.kind == "c"
     y_is_complex = dtype_y.kind == "c"
 
-    if x_is_complex:
-        ax = "%s_mul(a, x[i])" % complex_dtype_to_name(dtype_x)
-    elif not x_is_complex and y_is_complex:
-        ax = "{}_fromreal({})".format(complex_dtype_to_name(dtype_y), ax)
-    else:
-        ax = f"a*(({result_t}) x[i])"
+    if dtype_z.kind == "c":
+        # a and b are of the result dtype (dtype_z), so both are complex here.
+        z_ct = complex_dtype_to_name(dtype_z)
 
-    if y_is_complex:
-        by = "%s_mul(b, y[i])" % complex_dtype_to_name(dtype_y)
-    elif x_is_complex and not y_is_complex:
-        by = "{}_fromreal({})".format(complex_dtype_to_name(dtype_x), by)
+        if x_is_complex:
+            ax = f"{z_ct}_mul(a, {z_ct}_cast(x[i]))"
+        else:
+            ax = f"{z_ct}_mulr(a, x[i])"
+
+        if y_is_complex:
+            by = f"{z_ct}_mul(b, {z_ct}_cast(y[i]))"
+        else:
+            by = f"{z_ct}_mulr(b, y[i])"
+
+        result = f"{z_ct}_add({ax}, {by})"
     else:
+        # real-only
+
+        ax = f"a*(({result_t}) x[i])"
         by = f"b*(({result_t}) y[i])"
 
-    if x_is_complex or y_is_complex:
-        result = (
-                "{root}_add({root}_cast({ax}), {root}_cast({by}))"
-                .format(
-                    ax=ax,
-                    by=by,
-                    root=complex_dtype_to_name(dtype_z)))
-    else:
         result = f"{ax} + {by}"
 
     return get_elwise_kernel(context,
-- 
GitLab