# Unified diff against loopy/target/opencl.py (two hunks, shown here joined
# onto a single line).
#
# Hunk 1 (@@ -142,6 +142,24 @@) adds the module-level table
# VECTOR_LITERAL_FUNCS.  It is built with dict() from a generator over ten
# (type name, numpy dtype) pairs -- char/uchar/short/ushort/int/uint/long/
# ulong/float/double -- crossed with the counts 2, 3, 4, 8 and 16.  Each key
# is "make_<name><count>" (for example "make_float4") and each value is the
# tuple (name, dtype, count).
#
# Hunk 2 (@@ -188,6 +206,19 @@) adds a branch to opencl_function_mangler,
# placed just before the function's final `return None`:
#   * it only runs when `name` is a key of VECTOR_LITERAL_FUNCS;
#   * it returns None when len(arg_dtypes) differs from the table's count;
#   * otherwise it returns a CallMangleInfo whose target_name is
#     "(<name><count>) " (note the trailing space), whose single result dtype
#     is kernel.target.vector_dtype(NumpyType(dtype), count), and whose
#     arg_dtypes is `count` copies of NumpyType(dtype).
#
# NOTE(review): target_name is a parenthesised type name rather than a
# function identifier, so the generated call presumably reads as an OpenCL
# vector literal such as "(float4) (a, b, c, d)" -- confirm against the code
# generator that consumes CallMangleInfo.
# NOTE(review): the rest of opencl_function_mangler lies outside the hunk
# context and is not described here.
diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 362cbb79a31bfde38348d2b091896c75ebe8fddc..0d9532d732ac462c0fbd2163e61e26edd76b58d3 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -142,6 +142,24 @@ _CL_SIMPLE_MULTI_ARG_FUNCTIONS = { } +VECTOR_LITERAL_FUNCS = dict( + ("make_%s%d" % (name, count), (name, dtype, count)) + for name, dtype in [ + ('char', np.int8), + ('uchar', np.uint8), + ('short', np.int16), + ('ushort', np.uint16), + ('int', np.int32), + ('uint', np.uint32), + ('long', np.int64), + ('ulong', np.uint64), + ('float', np.float32), + ('double', np.float64), + ] + for count in [2, 3, 4, 8, 16] + ) + + def opencl_function_mangler(kernel, name, arg_dtypes): if not isinstance(name, str): return None @@ -188,6 +206,19 @@ def opencl_function_mangler(kernel, name, arg_dtypes): result_dtypes=(result_dtype,), arg_dtypes=(result_dtype,)*3) + if name in VECTOR_LITERAL_FUNCS: + base_tp_name, dtype, count = VECTOR_LITERAL_FUNCS[name] + + if count != len(arg_dtypes): + return None + + from loopy.types import NumpyType + return CallMangleInfo( + target_name="(%s%d) " % (base_tp_name, count), + result_dtypes=(kernel.target.vector_dtype( + NumpyType(dtype), count),), + arg_dtypes=(NumpyType(dtype),)*count) + return None # }}}