diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py
index f083a70b0288d88d544285e0063eb2e677dfefcf..5cb6fbc19be90e68e60b914e8b96fc739df44439 100644
--- a/loopy/target/opencl.py
+++ b/loopy/target/opencl.py
@@ -137,6 +137,13 @@ def _register_vector_types(dtype_registry):
 
 # {{{ function mangler
 
+# Maps the name of each function handled by the generic multi-argument branch
+# of opencl_function_mangler to the number of arguments it must be called with.
+_CL_SIMPLE_MULTI_ARG_FUNCTIONS = {
+        "clamp": 3,
+        }
+
+
 def opencl_function_mangler(kernel, name, arg_dtypes):
     if not isinstance(name, str):
         return None
@@ -160,6 +167,23 @@ def opencl_function_mangler(kernel, name, arg_dtypes):
         scalar_dtype, offset, field_name = arg_dtypes[0].numpy_dtype.fields["s0"]
         return NumpyType(scalar_dtype), name
 
+    if name in _CL_SIMPLE_MULTI_ARG_FUNCTIONS:
+        num_args = _CL_SIMPLE_MULTI_ARG_FUNCTIONS[name]
+        if len(arg_dtypes) != num_args:
+            raise LoopyError("%s takes %d arguments (%d received)"
+                    % (name, num_args, len(arg_dtypes)))
+
+        # np.find_common_type is deprecated as of NumPy 1.25 and removed in
+        # NumPy 2.0; np.result_type is the documented replacement.
+        dtype = np.result_type(
+                *[arg_dtype.numpy_dtype for arg_dtype in arg_dtypes])
+
+        if dtype.kind == "c":
+            raise LoopyError("%s does not support complex numbers"
+                    % name)
+
+        return NumpyType(dtype), name
+
     return None
 
 # }}}
diff --git a/test/test_loopy.py b/test/test_loopy.py
index cebbda4d9b9177bd0a9f63ff9b3126b3c188c76c..96286b73df5642ef98b04f1a7acb03c60b8d5e7c 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2473,6 +2473,31 @@ def test_atomic(ctx_factory, dtype):
     lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=10000))
 
 
+def test_clamp(ctx_factory):
+    """Run clamp() through a generated kernel and compare against np.clip."""
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    n = 15 * 10**6
+    x = cl.clrandom.rand(queue, n, dtype=np.float32)
+
+    # rand() is called with its default bounds (the unit interval, per the
+    # pyopencl docs), so keep a and b inside it to hit all three cases.
+    a = np.float32(0.25)
+    b = np.float32(0.75)
+
+    knl = lp.make_kernel(
+            "{ [i]: 0<=i<n }",
+            "out[i] = clamp(x[i], a, b)")
+
+    knl = lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0")
+    knl = lp.set_options(knl, write_cl=True)
+
+    evt, (out,) = knl(queue, x=x, a=a, b=b)
+
+    assert np.array_equal(out.get(), np.clip(x.get(), a, b))
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])