Skip to content
Snippets Groups Projects
Commit 75842489 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Some missing NumpyDtype wrapping

parent 8efe8657
No related branches found
No related tags found
No related merge requests found
......@@ -113,7 +113,8 @@ def opencl_function_mangler(kernel, name, arg_dtypes):
return None
if name in ["max", "min"] and len(arg_dtypes) == 2:
dtype = np.find_common_type([], arg_dtypes)
dtype = np.find_common_type(
[], [dtype.numpy_dtype for dtype in arg_dtypes])
if dtype.kind == "c":
raise RuntimeError("min/max do not support complex numbers")
......@@ -121,7 +122,7 @@ def opencl_function_mangler(kernel, name, arg_dtypes):
if dtype.kind == "f":
name = "f" + name
return dtype, name
return NumpyType(dtype), name
if name in "atan2" and len(arg_dtypes) == 2:
return arg_dtypes[0], name
......@@ -140,16 +141,16 @@ def opencl_function_mangler(kernel, name, arg_dtypes):
def opencl_symbol_mangler(kernel, name):
# FIXME: should be more picky about exact names
if name.startswith("FLT_"):
return np.dtype(np.float32), name
return NumpyType(np.dtype(np.float32)), name
elif name.startswith("DBL_"):
return np.dtype(np.float64), name
return NumpyType(np.dtype(np.float64)), name
elif name.startswith("M_"):
if name.endswith("_F"):
return np.dtype(np.float32), name
return NumpyType(np.dtype(np.float32)), name
else:
return np.dtype(np.float64), name
return NumpyType(np.dtype(np.float64)), name
elif name == "INFINITY":
return np.dtype(np.float32), name
return NumpyType(np.dtype(np.float32)), name
else:
return None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment