From e97aa103067f38520e3b5a3700822232cf62f2e7 Mon Sep 17 00:00:00 2001 From: Isuru Fernando <idf2@illinois.edu> Date: Sat, 5 Mar 2022 14:52:59 -0600 Subject: [PATCH] Allow abs(int) in opencl Tweak test_abs_as_index Fix treatment of abs in OpenCLCallable --- loopy/target/opencl.py | 22 ++++++++++++++++++++++ loopy/target/pyopencl.py | 5 +++-- test/test_apps.py | 13 +++++++++++++ 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 89710c023..39e484411 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -214,6 +214,28 @@ class OpenCLCallable(ScalarCallable): elif dtype.kind == "c": raise LoopyTypeError(f"{name} does not support type {dtype}") + return ( + self.copy(name_in_target=name, + arg_id_to_dtype={0: NumpyType(dtype), -1: + NumpyType(dtype)}), + callables_table) + elif name == "abs": + for id in arg_id_to_dtype: + if not -1 <= id <= 0: + raise LoopyError(f"'{name}' can take only one argument.") + + if 0 not in arg_id_to_dtype or arg_id_to_dtype[0] is None: + # the types provided aren't mature enough to specialize the + # callable + return ( + self.copy(arg_id_to_dtype=arg_id_to_dtype), + callables_table) + + dtype = arg_id_to_dtype[0].numpy_dtype + + if dtype.kind not in ("u", "i"): + raise LoopyTypeError(f"{name} does not support type {dtype}") + return ( self.copy(name_in_target=name, arg_id_to_dtype={0: NumpyType(dtype), -1: diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py index 06ff41908..7189af64a 100644 --- a/loopy/target/pyopencl.py +++ b/loopy/target/pyopencl.py @@ -107,8 +107,9 @@ class PyOpenCLCallable(ScalarCallable): # function calls for floating-point parameters. numpy_dtype = dtype.numpy_dtype if numpy_dtype.kind in ("u", "i"): - dtype = NumpyType(np.float32) - if name == "abs": + if name != "abs": + dtype = NumpyType(np.float32) + elif name == "abs": name = "fabs" return ( self.copy(name_in_target=name, diff --git a/test/test_apps.py b/test/test_apps.py index 6e49e73fa..5e7b387dd 100644 --- a/test/test_apps.py +++ b/test/test_apps.py @@ -697,6 +697,19 @@ def test_prefetch_through_indirect_access(): knl = lp.add_prefetch(knl, "map1[:, j]") +def test_abs_as_index(): + knl = lp.make_kernel( + ["{[i]: 0<=i<10}"], + """ + b[i] = a[abs(5-i)] + """, + [ + lp.GlobalArg("a", np.float32), + ... + ]) + print(lp.generate_code_v2(knl).device_code()) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab