From b71b4f8825b7dba2735ff7f4c8c4b99a5123b728 Mon Sep 17 00:00:00 2001
From: Isuru Fernando <idf2@illinois.edu>
Date: Sun, 6 Mar 2022 18:14:05 -0600
Subject: [PATCH] check for unsigned dtype for loopy_mod (#563)

* check for unsigned dtype for loopy_mod

* fix formatting

* simplify test
---
 loopy/target/c/codegen/expression.py |  8 ++++++--
 test/test_apps.py                    | 12 ++++++++++++
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py
index f54007e7e..4a94d1cf3 100644
--- a/loopy/target/c/codegen/expression.py
+++ b/loopy/target/c/codegen/expression.py
@@ -332,9 +332,13 @@ class ExpressionToCExpressionMapper(IdentityMapper):
         assumptions, domain = isl.align_two(assumption_non_param, domain)
         domain = domain & assumptions
 
+        num_type = self.infer_type(expr.numerator)
+        den_type = self.infer_type(expr.denominator)
         from loopy.isl_helpers import is_nonnegative
-        num_nonneg = is_nonnegative(expr.numerator, domain)
-        den_nonneg = is_nonnegative(expr.denominator, domain)
+        num_nonneg = is_nonnegative(expr.numerator, domain) \
+            or num_type.numpy_dtype.kind == "u"
+        den_nonneg = is_nonnegative(expr.denominator, domain) \
+            or den_type.numpy_dtype.kind == "u"
 
         result_dtype = self.infer_type(expr)
         suffix = result_dtype.numpy_dtype.type.__name__
diff --git a/test/test_apps.py b/test/test_apps.py
index 5e7b387dd..438a5945c 100644
--- a/test/test_apps.py
+++ b/test/test_apps.py
@@ -697,6 +697,18 @@ def test_prefetch_through_indirect_access():
         knl = lp.add_prefetch(knl, "map1[:, j]")
 
 
+def test_unsigned_types_to_mod():
+    knl = lp.make_kernel("{[i]: 0<=i<10}",
+        """
+            <> c = b[i] {id=init,dup=i}
+            a[i] = i % c {dep=init}
+        """,
+        [lp.GlobalArg("a", shape=(10,), dtype=np.uint32),
+         lp.GlobalArg("b", shape=(10,), dtype=np.uint32)]
+    )
+    assert "loopy_mod" not in lp.generate_code_v2(knl).device_code()
+
+
 def test_abs_as_index():
     knl = lp.make_kernel(
         ["{[i]: 0<=i<10}"],
-- 
GitLab