diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 1d60f4c8fc7420f88c01825c81e4d3b4664cc7f9..68be1ce3bf9d579592064e4707910802a70cdbc1 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -113,7 +113,8 @@ def opencl_function_mangler(kernel, name, arg_dtypes): return None if name in ["max", "min"] and len(arg_dtypes) == 2: - dtype = np.find_common_type([], arg_dtypes) + dtype = np.find_common_type( + [], [dtype.numpy_dtype for dtype in arg_dtypes]) if dtype.kind == "c": raise RuntimeError("min/max do not support complex numbers") @@ -121,7 +122,7 @@ def opencl_function_mangler(kernel, name, arg_dtypes): if dtype.kind == "f": name = "f" + name - return dtype, name + return NumpyType(dtype), name if name in "atan2" and len(arg_dtypes) == 2: return arg_dtypes[0], name @@ -140,16 +141,16 @@ def opencl_function_mangler(kernel, name, arg_dtypes): def opencl_symbol_mangler(kernel, name): # FIXME: should be more picky about exact names if name.startswith("FLT_"): - return np.dtype(np.float32), name + return NumpyType(np.dtype(np.float32)), name elif name.startswith("DBL_"): - return np.dtype(np.float64), name + return NumpyType(np.dtype(np.float64)), name elif name.startswith("M_"): if name.endswith("_F"): - return np.dtype(np.float32), name + return NumpyType(np.dtype(np.float32)), name else: - return np.dtype(np.float64), name + return NumpyType(np.dtype(np.float64)), name elif name == "INFINITY": - return np.dtype(np.float32), name + return NumpyType(np.dtype(np.float32)), name else: return None