From 3c04b61bc187179ba2a8973fb31b6341b827483c Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 9 Sep 2020 17:23:13 -0500 Subject: [PATCH] Fix complex codegen logic in axpbyz given scalars are of result dtype --- pyopencl/elementwise.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py index 9e6c762a..df364eda 100644 --- a/pyopencl/elementwise.py +++ b/pyopencl/elementwise.py @@ -508,28 +508,27 @@ def get_axpbyz_kernel(context, dtype_x, dtype_y, dtype_z): x_is_complex = dtype_x.kind == "c" y_is_complex = dtype_y.kind == "c" - if x_is_complex: - ax = "%s_mul(a, x[i])" % complex_dtype_to_name(dtype_x) - elif not x_is_complex and y_is_complex: - ax = "{}_fromreal({})".format(complex_dtype_to_name(dtype_y), ax) - else: - ax = f"a*(({result_t}) x[i])" + if dtype_z.kind == "c": + # a and b will always be complex here. + z_ct = complex_dtype_to_name(dtype_z) - if y_is_complex: - by = "%s_mul(b, y[i])" % complex_dtype_to_name(dtype_y) - elif x_is_complex and not y_is_complex: - by = "{}_fromreal({})".format(complex_dtype_to_name(dtype_x), by) + if x_is_complex: + ax = f"{z_ct}_mul(a, {z_ct}_cast(x[i]))" + else: + ax = f"{z_ct}_mulr(a, x[i])" + + if y_is_complex: + by = f"{z_ct}_mul(b, {z_ct}_cast(y[i]))" + else: + by = f"{z_ct}_mulr(b, y[i])" + + result = f"{z_ct}_add({ax}, {by})" else: + # real-only + + ax = f"a*(({result_t}) x[i])" by = f"b*(({result_t}) y[i])" - if x_is_complex or y_is_complex: - result = ( - "{root}_add({root}_cast({ax}), {root}_cast({by}))" - .format( - ax=ax, - by=by, - root=complex_dtype_to_name(dtype_z))) - else: result = f"{ax} + {by}" return get_elwise_kernel(context, -- GitLab