diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 14d51f2091ff39cc605e62ac0fca5f57f128ee48..e0669661403cd9b77e51908f9c1ee3f3251688db 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -139,6 +139,7 @@ def _register_vector_types(dtype_registry): _CL_SIMPLE_MULTI_ARG_FUNCTIONS = { "clamp": 3, + "atan2": 2, } @@ -164,7 +165,7 @@ def opencl_function_mangler(kernel, name, arg_dtypes): if not isinstance(name, str): return None - if name in ["max", "min", "atan2"] and len(arg_dtypes) == 2: + if name in ["max", "min"] and len(arg_dtypes) == 2: dtype = np.find_common_type( [], [dtype.numpy_dtype for dtype in arg_dtypes]) @@ -204,7 +205,7 @@ def opencl_function_mangler(kernel, name, arg_dtypes): return CallMangleInfo( target_name=name, result_dtypes=(result_dtype,), - arg_dtypes=(result_dtype,)*3) + arg_dtypes=(result_dtype,)*num_args) if name in VECTOR_LITERAL_FUNCS: base_tp_name, dtype, count = VECTOR_LITERAL_FUNCS[name]