From c7646489c973a962e1154d0c5d5bf38c7e9d8d6d Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sun, 8 May 2016 20:15:39 -0500 Subject: [PATCH] Initial batch of CL function descriptions --- loopy/target/opencl.py | 20 ++++++++++++++++++++ test/test_loopy.py | 17 +++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index f083a70b..5cb6fbc1 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -137,6 +137,11 @@ def _register_vector_types(dtype_registry): # {{{ function mangler +_CL_SIMPLE_MULTI_ARG_FUNCTIONS = { + "clamp": 3, + } + + def opencl_function_mangler(kernel, name, arg_dtypes): if not isinstance(name, str): return None @@ -160,6 +165,21 @@ def opencl_function_mangler(kernel, name, arg_dtypes): scalar_dtype, offset, field_name = arg_dtypes[0].numpy_dtype.fields["s0"] return NumpyType(scalar_dtype), name + if name in _CL_SIMPLE_MULTI_ARG_FUNCTIONS: + num_args = _CL_SIMPLE_MULTI_ARG_FUNCTIONS[name] + if len(arg_dtypes) != num_args: + raise LoopyError("%s takes %d arguments (%d received)" + % (name, num_args, len(arg_dtypes))) + + dtype = np.find_common_type( + [], [dtype.numpy_dtype for dtype in arg_dtypes]) + + if dtype.kind == "c": + raise LoopyError("%s does not support complex numbers" + % name) + + return NumpyType(dtype), name + return None # }}} diff --git a/test/test_loopy.py b/test/test_loopy.py index cebbda4d..96286b73 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2473,6 +2473,23 @@ def test_atomic(ctx_factory, dtype): lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=10000)) +def test_clamp(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + n = 15 * 10**6 + x = cl.clrandom.rand(queue, n, dtype=np.float32) + + knl = lp.make_kernel( + "{ [i]: 0<=i 1: exec(sys.argv[1]) -- GitLab