From 5be8abf64543d7d24f95cddc13333ca5468a07e0 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 26 Jun 2024 17:04:37 -0500
Subject: [PATCH] C expression casting logic: refactor, add some types

---
 loopy/target/c/codegen/expression.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py
index 04fad17c..d6793de4 100644
--- a/loopy/target/c/codegen/expression.py
+++ b/loopy/target/c/codegen/expression.py
@@ -21,6 +21,7 @@ THE SOFTWARE.
 """
 
 
+from typing import Optional
 import numpy as np
 
 from pymbolic.mapper import RecursiveMapper, IdentityMapper
@@ -109,7 +110,7 @@ class ExpressionToCExpressionMapper(IdentityMapper):
 
         return ary
 
-    def wrap_in_typecast(self, actual_type, needed_type, s):
+    def wrap_in_typecast(self, actual_type: LoopyType, needed_type: LoopyType, s):
         if actual_type != needed_type:
             registry = self.codegen_state.ast_builder.target.get_dtype_registry()
             cast = var("(%s) " % registry.dtype_to_ctype(needed_type))
@@ -117,13 +118,15 @@ class ExpressionToCExpressionMapper(IdentityMapper):
 
         return s
 
-    def rec(self, expr, type_context=None, needed_type=None):
-        if needed_type is None:
-            return RecursiveMapper.rec(self, expr, type_context)
+    def rec(self, expr, type_context=None, needed_type: Optional[LoopyType] = None):
+        result = RecursiveMapper.rec(self, expr, type_context)
 
-        return self.wrap_in_typecast(
-                self.infer_type(expr), needed_type,
-                RecursiveMapper.rec(self, expr, type_context))
+        if needed_type is None:
+            return result
+        else:
+            return self.wrap_in_typecast(
+                    self.infer_type(expr), needed_type,
+                    result)
 
     def __call__(self, expr, prec=None, type_context=None, needed_dtype=None):
         if prec is None:
-- 
GitLab