diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index 28e39e40cc2919ac1f4316ee32dbca3655900ee3..6a4ac40382d7cfa89eff30d994d4738f28c00c45 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -462,7 +462,18 @@ def default_function_mangler(name, arg_dtypes): def opencl_function_mangler(name, arg_dtypes): - if name == "atan2" and len(arg_dtypes) == 2: + if name in ["max", "min"] and len(arg_dtypes) == 2: + dtype = np.find_common_type([], arg_dtypes) + + if dtype.kind == "c": + raise RuntimeError("min/max do not support complex numbers") + + if dtype.kind == "f": + name = "f" + name + + return dtype, name + + if name in "atan2" and len(arg_dtypes) == 2: return arg_dtypes[0], name if len(arg_dtypes) == 1: @@ -511,6 +522,8 @@ def opencl_symbol_mangler(name): return np.dtype(np.float32), name else: return np.dtype(np.float64), name + elif name == "INFINITY": + return np.dtype(np.float32), name else: return None