From 8a2f855a7f358870826ee72d94925d2a68ce7d95 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 26 Feb 2018 00:27:13 -0600
Subject: [PATCH] Fix, test half-complex-half-not conditionals

---
 loopy/target/c/codegen/expression.py |  5 +++--
 loopy/version.py                     |  2 +-
 test/test_loopy.py                   | 13 +++++++++++++
 3 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py
index c111a02b7..59ed77f9c 100644
--- a/loopy/target/c/codegen/expression.py
+++ b/loopy/target/c/codegen/expression.py
@@ -324,10 +324,11 @@ class ExpressionToCExpressionMapper(IdentityMapper):
                     self.rec(expr.denominator, 'i'))
 
     def map_if(self, expr, type_context):
+        result_type = self.infer_type(expr)
         return type(expr)(
                 self.rec(expr.condition, "i"),
-                self.rec(expr.then, type_context),
-                self.rec(expr.else_, type_context),
+                self.rec(expr.then, type_context, result_type),
+                self.rec(expr.else_, type_context, result_type),
                 )
 
     def map_comparison(self, expr, type_context):
diff --git a/loopy/version.py b/loopy/version.py
index aeb0b277a..b3033c3a9 100644
--- a/loopy/version.py
+++ b/loopy/version.py
@@ -32,7 +32,7 @@ except ImportError:
 else:
     _islpy_version = islpy.version.VERSION_TEXT
 
-DATA_MODEL_VERSION = "v77-islpy%s" % _islpy_version
+DATA_MODEL_VERSION = "v78-islpy%s" % _islpy_version
 
 
 FALLBACK_LANGUAGE_VERSION = (2017, 2, 1)
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 8581ae5b8..86351dd93 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2856,6 +2856,19 @@ def test_no_barriers_for_nonoverlapping_access(second_index, expect_barrier):
     assert barrier_between(knl, "first", "second") == expect_barrier
 
 
+def test_half_complex_conditional(ctx_factory):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    knl = lp.make_kernel(
+            "{[i]: 0 <= i < 10}",
+            """
+           tmp[i] = if(i < 5, 0, 0j)
+           """)
+
+    knl(queue)
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab