diff --git a/src/cl/pyopencl-complex.h b/src/cl/pyopencl-complex.h index 57e3a963f468a55b8bc01387611f0cf166ea66b4..005ffab3dd68e5d08b8b48df02de9a5dd66a6c7b 100644 --- a/src/cl/pyopencl-complex.h +++ b/src/cl/pyopencl-complex.h @@ -24,8 +24,6 @@ // functions as visible below, e.g. cdouble_log(z). // // Under the hood, the complex types are simply float2 and double2. -// Note that addition (real + complex) and multiplication (complex*complex) -// are defined, but yield wrong results. #define PYOPENCL_DECLARE_COMPLEX_TYPE_INT(REAL_TP, REAL_3LTR, TPROOT, TP) \ \ @@ -36,10 +34,18 @@ TP TPROOT##_fromreal(REAL_TP a) { return (TP)(a, 0); } \ TP TPROOT##_conj(TP a) { return (TP)(a.x, -a.y); } \ \ + TP TPROOT##_add(TP a, TP b) \ + { \ + return a+b; \ + } \ TP TPROOT##_addr(TP a, REAL_TP b) \ { \ return (TP)(b+a.x, a.y); \ } \ + TP TPROOT##_radd(REAL_TP a, TP b) \ + { \ + return (TP)(a+b.x, b.y); \ + } \ \ TP TPROOT##_mul(TP a, TP b) \ { \ @@ -53,6 +59,11 @@ return a*b; \ } \ \ + TP TPROOT##_rmul(REAL_TP a, TP b) \ + { \ + return a*b; \ + } \ + \ TP TPROOT##_rdivide(REAL_TP z1, TP z2) \ { \ if (fabs(z2.x) <= fabs(z2.y)) { \