Skip to content
Snippets Groups Projects
Commit 3c04b61b authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Fix complex codegen logic in axpbyz given scalars are of result dtype

parent 670c7790
No related branches found
No related tags found
No related merge requests found
......@@ -508,28 +508,27 @@ def get_axpbyz_kernel(context, dtype_x, dtype_y, dtype_z):
x_is_complex = dtype_x.kind == "c"
y_is_complex = dtype_y.kind == "c"
if x_is_complex:
ax = "%s_mul(a, x[i])" % complex_dtype_to_name(dtype_x)
elif not x_is_complex and y_is_complex:
ax = "{}_fromreal({})".format(complex_dtype_to_name(dtype_y), ax)
else:
ax = f"a*(({result_t}) x[i])"
if dtype_z.kind == "c":
# a and b will always be complex here.
z_ct = complex_dtype_to_name(dtype_z)
if y_is_complex:
by = "%s_mul(b, y[i])" % complex_dtype_to_name(dtype_y)
elif x_is_complex and not y_is_complex:
by = "{}_fromreal({})".format(complex_dtype_to_name(dtype_x), by)
if x_is_complex:
ax = f"{z_ct}_mul(a, {z_ct}_cast(x[i]))"
else:
ax = f"{z_ct}_mulr(a, x[i])"
if y_is_complex:
by = f"{z_ct}_mul(b, {z_ct}_cast(y[i]))"
else:
by = f"{z_ct}_mulr(b, y[i])"
result = f"{z_ct}_add({ax}, {by})"
else:
# real-only
ax = f"a*(({result_t}) x[i])"
by = f"b*(({result_t}) y[i])"
if x_is_complex or y_is_complex:
result = (
"{root}_add({root}_cast({ax}), {root}_cast({by}))"
.format(
ax=ax,
by=by,
root=complex_dtype_to_name(dtype_z)))
else:
result = f"{ax} + {by}"
return get_elwise_kernel(context,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment