From 3b929d63a52c06d035f13e0c926a0638fbfa543a Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992@gmail.com>
Date: Sun, 18 May 2014 09:43:10 -0400
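Subject: [PATCH] add argument types to elementwise binary kernels

The fmod, modf, frexp, ldexp and bessel kernel getters previously declared
their arguments as fixed C strings (float for the first four, double/int for
bessel). They now take dtype parameters and build VectorArg/ScalarArg lists;
the defaults (np.float32, or np.float64/np.int32 for bessel) match the
previously hard-coded types.

get_fmod_kernel and get_ldexp_kernel are now built on
get_binary_func_kernel, which gains optional `preamble` and `name`
arguments; ldexp goes through a _PYOCL_LDEXP macro so the exponent is still
cast to int. get_unary_func_kernel and get_binary_func_kernel are moved
ahead of these getters. The bessel preamble additionally enables
cl_khr_fp64 and defines PYOPENCL_DEFINE_CDOUBLE.
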
---
 pyopencl/elementwise.py | 110 +++++++++++++++++++++++-----------------
 1 file changed, 64 insertions(+), 46 deletions(-)

diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py
index 0ec202ff..8d583681 100644
--- a/pyopencl/elementwise.py
+++ b/pyopencl/elementwise.py
@@ -812,74 +812,92 @@ def get_array_comparison_kernel(context, operator, dtype_a, dtype_b):
 
 
 @context_dependent_memoize
-def get_fmod_kernel(context):
-    return get_elwise_kernel(context,
-            "float *z, float *arg, float *mod",
-            "z[i] = fmod(arg[i], mod[i])",
-            name="fmod_kernel")
-
+def get_unary_func_kernel(context, func_name, in_dtype, out_dtype=None):
+    if out_dtype is None:
+        out_dtype = in_dtype
 
-@context_dependent_memoize
-def get_modf_kernel(context):
-    return get_elwise_kernel(context,
-            "float *intpart ,float *fracpart, float *x",
-            "fracpart[i] = modf(x[i], &intpart[i])",
-            name="modf_kernel")
+    return get_elwise_kernel(context, [
+        VectorArg(out_dtype, "z", with_offset=True),
+        VectorArg(in_dtype, "y", with_offset=True),
+        ],
+        "z[i] = %s(y[i])" % func_name,
+        name="%s_kernel" % func_name)
 
 
 @context_dependent_memoize
-def get_frexp_kernel(context):
-    return get_elwise_kernel(context,
-            "float *significand, float *exponent, float *x",
-            """
-                int expt = 0;
-                significand[i] = frexp(x[i], &expt);
-                exponent[i] = expt;
-            """,
-            name="frexp_kernel")
+def get_binary_func_kernel(context, func_name, x_dtype, y_dtype, out_dtype,
+                           preamble="", name=None):
+    return get_elwise_kernel(context, [
+        VectorArg(out_dtype, "z", with_offset=True),
+        VectorArg(x_dtype, "x", with_offset=True),
+        VectorArg(y_dtype, "y", with_offset=True),
+        ],
+        "z[i] = %s(x[i], y[i])" % func_name,
+        name="%s_kernel" % func_name if name is None else name,
+        preamble=preamble)
 
 
 @context_dependent_memoize
-def get_ldexp_kernel(context):
-    return get_elwise_kernel(context,
-            "float *z, float *sig, float *expt",
-            "z[i] = ldexp(sig[i], (int) expt[i])",
-            name="ldexp_kernel")
+def get_fmod_kernel(context, out_dtype=np.float32, arg_dtype=np.float32,
+                    mod_dtype=np.float32):
+    return get_binary_func_kernel(context, 'fmod', arg_dtype,
+                                  mod_dtype, out_dtype)
 
 
 @context_dependent_memoize
-def get_bessel_kernel(context, which_func):
-    return get_elwise_kernel(context,
-            "double *z, int ord_n, double *x",
-            "z[i] = bessel_%sn(ord_n, x[i])" % which_func,
-            name="bessel_%sn_kernel" % which_func,
-            preamble="""
-            #include <pyopencl-bessel-%s.cl>
-            """ % which_func)
+def get_modf_kernel(context, int_dtype=np.float32,
+                    frac_dtype=np.float32, x_dtype=np.float32):
+    return get_elwise_kernel(context, [
+        VectorArg(int_dtype, "intpart", with_offset=True),
+        VectorArg(frac_dtype, "fracpart", with_offset=True),
+        VectorArg(x_dtype, "x", with_offset=True),
+        ],
+        """
+        fracpart[i] = modf(x[i], &intpart[i])
+        """,
+        name="modf_kernel")
 
 
 @context_dependent_memoize
-def get_unary_func_kernel(context, func_name, in_dtype, out_dtype=None):
-    if out_dtype is None:
-        out_dtype = in_dtype
-
+def get_frexp_kernel(context, sign_dtype=np.float32, exp_dtype=np.float32,
+                     x_dtype=np.float32):
     return get_elwise_kernel(context, [
-        VectorArg(out_dtype, "z", with_offset=True),
-        VectorArg(in_dtype, "y", with_offset=True),
+        VectorArg(sign_dtype, "significand", with_offset=True),
+        VectorArg(exp_dtype, "exponent", with_offset=True),
+        VectorArg(x_dtype, "x", with_offset=True),
         ],
-        "z[i] = %s(y[i])" % func_name,
-        name="%s_kernel" % func_name)
+        """
+        int expt = 0;
+        significand[i] = frexp(x[i], &expt);
+        exponent[i] = expt;
+        """,
+        name="frexp_kernel")
+
+
+@context_dependent_memoize
+def get_ldexp_kernel(context, out_dtype=np.float32, sig_dtype=np.float32,
+                     expt_dtype=np.float32):
+    return get_binary_func_kernel(
+        context, '_PYOCL_LDEXP', sig_dtype, expt_dtype, out_dtype,
+        preamble="#define _PYOCL_LDEXP(x, y) ldexp(x, (int)(y))",
+        name="ldexp_kernel")
 
 
 @context_dependent_memoize
-def get_binary_func_kernel(context, func_name, x_dtype, y_dtype, out_dtype):
+def get_bessel_kernel(context, which_func, out_dtype=np.float64,
+                      order_dtype=np.int32, x_dtype=np.float64):
     return get_elwise_kernel(context, [
         VectorArg(out_dtype, "z", with_offset=True),
+        ScalarArg(order_dtype, "ord_n"),
         VectorArg(x_dtype, "x", with_offset=True),
-        VectorArg(y_dtype, "y", with_offset=True),
         ],
-        "z[i] = %s(x[i], y[i])" % func_name,
-        name="%s_kernel" % func_name)
+        "z[i] = bessel_%sn(ord_n, x[i])" % which_func,
+        name="bessel_%sn_kernel" % which_func,
+        preamble="""
+        #pragma OPENCL EXTENSION cl_khr_fp64: enable
+        #define PYOPENCL_DEFINE_CDOUBLE
+        #include <pyopencl-bessel-%s.cl>
+        """ % which_func)
 
 
 @context_dependent_memoize
-- 
GitLab