diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 89710c02386c84ea4b37f95962269cfb4a015e07..39e4844116f10e7527f35ea34b3aaa5027a20e93 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -214,6 +214,28 @@ class OpenCLCallable(ScalarCallable): elif dtype.kind == "c": raise LoopyTypeError(f"{name} does not support type {dtype}") + return ( + self.copy(name_in_target=name, + arg_id_to_dtype={0: NumpyType(dtype), -1: + NumpyType(dtype)}), + callables_table) + elif name == "abs": + for id in arg_id_to_dtype: + if not -1 <= id <= 0: + raise LoopyError(f"'{name}' can take only one argument.") + + if 0 not in arg_id_to_dtype or arg_id_to_dtype[0] is None: + # the types provided aren't mature enough to specialize the + # callable + return ( + self.copy(arg_id_to_dtype=arg_id_to_dtype), + callables_table) + + dtype = arg_id_to_dtype[0].numpy_dtype + + if dtype.kind not in ("u", "i"): + raise LoopyTypeError(f"{name} does not support type {dtype}") + return ( self.copy(name_in_target=name, arg_id_to_dtype={0: NumpyType(dtype), -1: diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py index 06ff41908d2eb2c2d9fd483e9d1acec5f8e98a62..7189af64ac8fc0b687c18d693771b64b4904d671 100644 --- a/loopy/target/pyopencl.py +++ b/loopy/target/pyopencl.py @@ -107,8 +107,9 @@ class PyOpenCLCallable(ScalarCallable): # function calls for floating-point parameters. numpy_dtype = dtype.numpy_dtype if numpy_dtype.kind in ("u", "i"): - dtype = NumpyType(np.float32) - if name == "abs": + if name != "abs": + dtype = NumpyType(np.float32) + elif name == "abs": name = "fabs" return ( self.copy(name_in_target=name, diff --git a/test/test_apps.py b/test/test_apps.py index 6e49e73fafae569411ad68fb8fefd24b5315087f..5e7b387dd9f0740a3328d115027b0f1aae005e04 100644 --- a/test/test_apps.py +++ b/test/test_apps.py @@ -697,6 +697,19 @@ def test_prefetch_through_indirect_access(): knl = lp.add_prefetch(knl, "map1[:, j]") +def test_abs_as_index(): + knl = lp.make_kernel( + ["{[i]: 0<=i<10}"], + """ + b[i] = a[abs(5-i)] + """, + [ + lp.GlobalArg("a", np.float32), + ... + ]) + print(lp.generate_code_v2(knl).device_code()) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])