diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py index 7189af64ac8fc0b687c18d693771b64b4904d671..309c77fbeef538ce147987953e566c0c4b80c56b 100644 --- a/loopy/target/pyopencl.py +++ b/loopy/target/pyopencl.py @@ -89,7 +89,7 @@ class PyOpenCLCallable(ScalarCallable): if name in ["sqrt", "exp", "log", "sin", "cos", "tan", "sinh", "cosh", "tanh", - "conj", "abs"]: + "conj"]: if dtype.is_complex(): # function parameters are complex. if dtype.numpy_dtype == np.complex64: @@ -103,22 +103,15 @@ class PyOpenCLCallable(ScalarCallable): self.copy(name_in_target=f"{tpname}_{name}", arg_id_to_dtype={0: dtype, -1: dtype}), callables_table) - else: - # function calls for floating-point parameters. - numpy_dtype = dtype.numpy_dtype - if numpy_dtype.kind in ("u", "i"): - if name != "abs": - dtype = NumpyType(np.float32) - elif name == "abs": - name = "fabs" - return ( - self.copy(name_in_target=name, - arg_id_to_dtype={0: dtype, -1: dtype}), - callables_table) - return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), - callables_table) + # fall back to pure OpenCL for real-valued arguments + + from loopy.target.opencl import OpenCLCallable + return OpenCLCallable(name, + arg_id_to_dtype=self.arg_id_to_dtype, + arg_id_to_descr=self.arg_id_to_descr, + name_in_target=self.name_in_target).with_types( + arg_id_to_dtype, callables_table) def generate_preambles(self, target): name = self.name_in_target