From e97aa103067f38520e3b5a3700822232cf62f2e7 Mon Sep 17 00:00:00 2001
From: Isuru Fernando <idf2@illinois.edu>
Date: Sat, 5 Mar 2022 14:52:59 -0600
Subject: [PATCH] Allow abs(int) in opencl

Tweak test_abs_as_index

Fix treatment of abs in OpenCLCallable
---
 loopy/target/opencl.py   | 22 ++++++++++++++++++++++
 loopy/target/pyopencl.py |  5 +++--
 test/test_apps.py        | 13 +++++++++++++
 3 files changed, 38 insertions(+), 2 deletions(-)

diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py
index 89710c023..39e484411 100644
--- a/loopy/target/opencl.py
+++ b/loopy/target/opencl.py
@@ -214,6 +214,28 @@ class OpenCLCallable(ScalarCallable):
             elif dtype.kind == "c":
                 raise LoopyTypeError(f"{name} does not support type {dtype}")
 
+            return (
+                    self.copy(name_in_target=name,
+                        arg_id_to_dtype={0: NumpyType(dtype), -1:
+                            NumpyType(dtype)}),
+                    callables_table)
+        elif name == "abs":
+            for id in arg_id_to_dtype:
+                if not -1 <= id <= 0:
+                    raise LoopyError(f"'{name}' can take only one argument.")
+
+            if 0 not in arg_id_to_dtype or arg_id_to_dtype[0] is None:
+                # the types provided aren't mature enough to specialize the
+                # callable
+                return (
+                        self.copy(arg_id_to_dtype=arg_id_to_dtype),
+                        callables_table)
+
+            dtype = arg_id_to_dtype[0].numpy_dtype
+
+            if dtype.kind not in ("u", "i"):
+                raise LoopyTypeError(f"{name} does not support type {dtype}")
+
             return (
                     self.copy(name_in_target=name,
                         arg_id_to_dtype={0: NumpyType(dtype), -1:
diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py
index 06ff41908..7189af64a 100644
--- a/loopy/target/pyopencl.py
+++ b/loopy/target/pyopencl.py
@@ -107,8 +107,9 @@ class PyOpenCLCallable(ScalarCallable):
                 # function calls for floating-point parameters.
                 numpy_dtype = dtype.numpy_dtype
                 if numpy_dtype.kind in ("u", "i"):
-                    dtype = NumpyType(np.float32)
-                if name == "abs":
+                    if name != "abs":
+                        dtype = NumpyType(np.float32)
+                elif name == "abs":
                     name = "fabs"
                 return (
                         self.copy(name_in_target=name,
diff --git a/test/test_apps.py b/test/test_apps.py
index 6e49e73fa..5e7b387dd 100644
--- a/test/test_apps.py
+++ b/test/test_apps.py
@@ -697,6 +697,19 @@ def test_prefetch_through_indirect_access():
         knl = lp.add_prefetch(knl, "map1[:, j]")
 
 
+def test_abs_as_index():
+    knl = lp.make_kernel(
+        ["{[i]: 0<=i<10}"],
+        """
+        b[i] = a[abs(5-i)]
+        """,
+        [
+            lp.GlobalArg("a", np.float32),
+            ...
+            ])
+    print(lp.generate_code_v2(knl).device_code())
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab