From 6894a67379320ea25feaf15855ee499fe9b8edb0 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Sat, 21 Dec 2024 16:15:04 +0200
Subject: [PATCH] mypy: fix errors from stricter cgen typing

---
 loopy/target/c/__init__.py | 23 ++++++++++++++++-------
 loopy/target/opencl.py     |  9 +++------
 loopy/target/pyopencl.py   | 21 ++++++++++++++++-----
 3 files changed, 35 insertions(+), 18 deletions(-)

diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py
index e37194e4..a4990b5c 100644
--- a/loopy/target/c/__init__.py
+++ b/loopy/target/c/__init__.py
@@ -795,9 +795,12 @@ class CFamilyASTBuilder(ASTBuilderBase[Generable]):
     # {{{ code generation
 
     def get_function_definition(
-            self, codegen_state: CodeGenerationState,
+            self,
+            codegen_state: CodeGenerationState,
             codegen_result: CodeGenerationResult,
-            schedule_index: int, function_decl: Generable, function_body: Generable
+            schedule_index: int,
+            function_decl: Generable,
+            function_body: Generable
             ) -> Generable:
         kernel = codegen_state.kernel
         assert kernel.linearization is not None
@@ -825,16 +828,23 @@ class CFamilyASTBuilder(ASTBuilderBase[Generable]):
                         tv.initializer is not None):
                     assert tv.read_only
 
-                    decl: Generable = self.wrap_global_constant(
+                    decl = self.wrap_global_constant(
                             self.get_temporary_var_declarator(codegen_state, tv))
 
                     if tv.initializer is not None:
-                        decl = Initializer(decl, generate_array_literal(
+                        init_decl = Initializer(decl, generate_array_literal(
                             codegen_state, tv, tv.initializer))
+                    else:
+                        init_decl = decl
 
-                    result.append(decl)
+                    result.append(init_decl)
+
+        assert isinstance(function_decl, FunctionDeclarationWrapper)
+        if not isinstance(function_body, Block):
+            function_body = Block([function_body])
 
         fbody = FunctionBody(function_decl, function_body)
+
         if not result:
             return fbody
         else:
@@ -1338,8 +1348,7 @@ class CFunctionDeclExtractor(CASTIdentityMapper):
 
     def map_function_decl_wrapper(self, node):
         self.decls.append(node.subdecl)
-        return super()\
-                .map_function_decl_wrapper(node)
+        return super().map_function_decl_wrapper(node)
 
 
 def generate_header(kernel, codegen_result=None):
diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py
index 4218ae9f..d14dd9e3 100644
--- a/loopy/target/opencl.py
+++ b/loopy/target/opencl.py
@@ -24,7 +24,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Literal, Sequence
 
 import numpy as np
 
@@ -766,12 +766,9 @@ class OpenCLCASTBuilder(CFamilyASTBuilder):
 
     def get_image_arg_declarator(
             self, arg: ImageArg, is_written: bool) -> Declarator:
-        if is_written:
-            mode = "w"
-        else:
-            mode = "r"
-
         from cgen.opencl import CLImage
+
+        mode: Literal["r", "w"] = "w" if is_written else "r"
         return CLImage(arg.num_target_axes(), mode, arg.name)
 
     # }}}
diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py
index d3a2373a..9add453d 100644
--- a/loopy/target/pyopencl.py
+++ b/loopy/target/pyopencl.py
@@ -1026,9 +1026,12 @@ class PyOpenCLCASTBuilder(OpenCLCASTBuilder):
     # {{{ function decl/def, with arg overflow handling
 
     def get_function_definition(
-            self, codegen_state: CodeGenerationState,
+            self,
+            codegen_state: CodeGenerationState,
             codegen_result: CodeGenerationResult,
-            schedule_index: int, function_decl: Generable, function_body: Generable,
+            schedule_index: int,
+            function_decl: Generable,
+            function_body: Generable,
             ) -> Generable:
         assert isinstance(function_body, Block)
         kernel = codegen_state.kernel
@@ -1057,15 +1060,17 @@ class PyOpenCLCASTBuilder(OpenCLCASTBuilder):
                         tv.initializer is not None):
                     assert tv.read_only
 
-                    decl: Generable = self.wrap_global_constant(
+                    decl = self.wrap_global_constant(
                             self.get_temporary_var_declarator(codegen_state, tv))
 
                     if tv.initializer is not None:
                         from loopy.target.c import generate_array_literal
-                        decl = Initializer(decl, generate_array_literal(
+                        init_decl = Initializer(decl, generate_array_literal(
                             codegen_state, tv, tv.initializer))
+                    else:
+                        init_decl = decl
 
-                    result.append(decl)
+                    result.append(init_decl)
 
         # {{{ unpack overflow args
 
@@ -1091,6 +1096,12 @@ class PyOpenCLCASTBuilder(OpenCLCASTBuilder):
 
         # }}}
 
+        from loopy.target.c import FunctionDeclarationWrapper
+
+        assert isinstance(function_decl, FunctionDeclarationWrapper)
+        if not isinstance(function_body, Block):
+            function_body = Block([function_body])
+
         fbody = FunctionBody(function_decl, function_body)
         if not result:
             return fbody
-- 
GitLab