From 3effd59e9920642121d48c3e5be37fd57347f359 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 26 Jun 2024 15:31:00 -0500
Subject: [PATCH] Test, fix complex-valued TypeCast in PyOpenCL

---
 loopy/target/c/codegen/expression.py |  7 +++----
 loopy/target/pyopencl.py             |  1 +
 test/test_target.py                  | 30 ++++++++++++++++++++++++++++
 3 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py
index a29325fa9..6c23fcd10 100644
--- a/loopy/target/c/codegen/expression.py
+++ b/loopy/target/c/codegen/expression.py
@@ -42,6 +42,7 @@ from loopy.tools import is_integer
 from loopy.types import LoopyType
 from loopy.target.c import CExpression
 from loopy.typing import ExpressionT
+from loopy.symbolic import TypeCast
 
 
 __doc__ = """
@@ -415,10 +416,8 @@ class ExpressionToCExpressionMapper(IdentityMapper):
                     expr.operator,
                     self.rec(expr.right, inner_type_context))
 
-    def map_type_cast(self, expr, type_context):
-        registry = self.codegen_state.ast_builder.target.get_dtype_registry()
-        cast = var("(%s)" % registry.dtype_to_ctype(expr.type))
-        return cast(self.rec(expr.child, type_context))
+    def map_type_cast(self, expr: TypeCast, type_context: str):
+        return self.rec(expr.child, type_context, expr.type)
 
     def map_constant(self, expr, type_context):
         from loopy.symbolic import Literal
diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py
index d4a8944f4..6357cfda0 100644
--- a/loopy/target/pyopencl.py
+++ b/loopy/target/pyopencl.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+
 """OpenCL target integrated with PyOpenCL."""
 
 __copyright__ = "Copyright (C) 2015 Andreas Kloeckner"
diff --git a/test/test_target.py b/test/test_target.py
index 10a04ed5b..5a41a1539 100644
--- a/test/test_target.py
+++ b/test/test_target.py
@@ -805,6 +805,36 @@ def test_ispc_private_var():
     print(cg_result.device_code())
 
 
+def test_to_complex_casts(ctx_factory):
+    arith_dtypes = "bhilqpBHILQPfdFD"
+
+    out_type = lp.to_loopy_type(np.dtype(np.complex128))
+    other = np.complex64(7)
+    from pymbolic import var
+
+    knl = lp.make_kernel(
+        [],
+        [
+            lp.Assignment(
+                f"out_{typename}",
+                lp.TypeCast(out_type, var(f"in_{typename}"))
+                +
+                lp.TypeCast(out_type, other)
+            )
+            for typename in arith_dtypes
+        ],
+        [
+            lp.GlobalArg(f"in_{typename}", dtype=np.dtype(typename), shape=())
+            for typename in arith_dtypes
+        ] + [...]
+    )
+
+    ctx = ctx_factory()
+    code = lp.generate_code_v2(knl).device_code()
+    # Only checks that the generated code builds; results are not verified.
+    cl.Program(ctx, code).build()
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab