From 6894a67379320ea25feaf15855ee499fe9b8edb0 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Sat, 21 Dec 2024 16:15:04 +0200 Subject: [PATCH] mypy: fix errors from stricter cgen typing --- loopy/target/c/__init__.py | 23 ++++++++++++++++------- loopy/target/opencl.py | 9 +++------ loopy/target/pyopencl.py | 21 ++++++++++++++++----- 3 files changed, 35 insertions(+), 18 deletions(-) diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py index e37194e4..a4990b5c 100644 --- a/loopy/target/c/__init__.py +++ b/loopy/target/c/__init__.py @@ -795,9 +795,12 @@ class CFamilyASTBuilder(ASTBuilderBase[Generable]): # {{{ code generation def get_function_definition( - self, codegen_state: CodeGenerationState, + self, + codegen_state: CodeGenerationState, codegen_result: CodeGenerationResult, - schedule_index: int, function_decl: Generable, function_body: Generable + schedule_index: int, + function_decl: Generable, + function_body: Generable ) -> Generable: kernel = codegen_state.kernel assert kernel.linearization is not None @@ -825,16 +828,23 @@ class CFamilyASTBuilder(ASTBuilderBase[Generable]): tv.initializer is not None): assert tv.read_only - decl: Generable = self.wrap_global_constant( + decl = self.wrap_global_constant( self.get_temporary_var_declarator(codegen_state, tv)) if tv.initializer is not None: - decl = Initializer(decl, generate_array_literal( + init_decl = Initializer(decl, generate_array_literal( codegen_state, tv, tv.initializer)) + else: + init_decl = decl - result.append(decl) + result.append(init_decl) + + assert isinstance(function_decl, FunctionDeclarationWrapper) + if not isinstance(function_body, Block): + function_body = Block([function_body]) fbody = FunctionBody(function_decl, function_body) + if not result: return fbody else: @@ -1338,8 +1348,7 @@ class CFunctionDeclExtractor(CASTIdentityMapper): def map_function_decl_wrapper(self, node): self.decls.append(node.subdecl) - return super()\ - .map_function_decl_wrapper(node) + return super().map_function_decl_wrapper(node) def generate_header(kernel, codegen_result=None): diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 4218ae9f..d14dd9e3 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -24,7 +24,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Literal, Sequence import numpy as np @@ -766,12 +766,9 @@ class OpenCLCASTBuilder(CFamilyASTBuilder): def get_image_arg_declarator( self, arg: ImageArg, is_written: bool) -> Declarator: - if is_written: - mode = "w" - else: - mode = "r" - from cgen.opencl import CLImage + + mode: Literal["r", "w"] = "w" if is_written else "r" return CLImage(arg.num_target_axes(), mode, arg.name) # }}} diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py index d3a2373a..9add453d 100644 --- a/loopy/target/pyopencl.py +++ b/loopy/target/pyopencl.py @@ -1026,9 +1026,12 @@ class PyOpenCLCASTBuilder(OpenCLCASTBuilder): # {{{ function decl/def, with arg overflow handling def get_function_definition( - self, codegen_state: CodeGenerationState, + self, + codegen_state: CodeGenerationState, codegen_result: CodeGenerationResult, - schedule_index: int, function_decl: Generable, function_body: Generable, + schedule_index: int, + function_decl: Generable, + function_body: Generable, ) -> Generable: assert isinstance(function_body, Block) kernel = codegen_state.kernel @@ -1057,15 +1060,17 @@ class PyOpenCLCASTBuilder(OpenCLCASTBuilder): tv.initializer is not None): assert tv.read_only - decl: Generable = self.wrap_global_constant( + decl = self.wrap_global_constant( self.get_temporary_var_declarator(codegen_state, tv)) if tv.initializer is not None: from loopy.target.c import generate_array_literal - decl = Initializer(decl, generate_array_literal( + init_decl = Initializer(decl, generate_array_literal( codegen_state, tv, tv.initializer)) + else: + init_decl = decl - result.append(decl) + result.append(init_decl) # {{{ unpack overflow args @@ -1091,6 +1096,12 @@ class PyOpenCLCASTBuilder(OpenCLCASTBuilder): # }}} + from loopy.target.c import FunctionDeclarationWrapper + + assert isinstance(function_decl, FunctionDeclarationWrapper) + if not isinstance(function_body, Block): + function_body = Block([function_body]) + fbody = FunctionBody(function_decl, function_body) if not result: return fbody -- GitLab