From 3effd59e9920642121d48c3e5be37fd57347f359 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 26 Jun 2024 15:31:00 -0500 Subject: [PATCH] Test, fix complex-valued TypeCast in PyOpenCL --- loopy/target/c/codegen/expression.py | 7 +++---- loopy/target/pyopencl.py | 1 + test/test_target.py | 30 ++++++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py index a29325fa9..6c23fcd10 100644 --- a/loopy/target/c/codegen/expression.py +++ b/loopy/target/c/codegen/expression.py @@ -42,6 +42,7 @@ from loopy.tools import is_integer from loopy.types import LoopyType from loopy.target.c import CExpression from loopy.typing import ExpressionT +from loopy.symbolic import TypeCast __doc__ = """ @@ -415,10 +416,8 @@ class ExpressionToCExpressionMapper(IdentityMapper): expr.operator, self.rec(expr.right, inner_type_context)) - def map_type_cast(self, expr, type_context): - registry = self.codegen_state.ast_builder.target.get_dtype_registry() - cast = var("(%s)" % registry.dtype_to_ctype(expr.type)) - return cast(self.rec(expr.child, type_context)) + def map_type_cast(self, expr: TypeCast, type_context: str): + return self.rec(expr.child, type_context, expr.type) def map_constant(self, expr, type_context): from loopy.symbolic import Literal diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py index d4a8944f4..6357cfda0 100644 --- a/loopy/target/pyopencl.py +++ b/loopy/target/pyopencl.py @@ -1,4 +1,5 @@ from __future__ import annotations + """OpenCL target integrated with PyOpenCL.""" __copyright__ = "Copyright (C) 2015 Andreas Kloeckner" diff --git a/test/test_target.py b/test/test_target.py index 10a04ed5b..5a41a1539 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -805,6 +805,36 @@ def test_ispc_private_var(): print(cg_result.device_code()) +def test_to_complex_casts(ctx_factory): + arith_dtypes = "bhilqpBHILQPfdFD" + + out_type = lp.to_loopy_type(np.dtype(np.complex128)) + other = np.complex64(7) + from pymbolic import var + + knl = lp.make_kernel( + [], + [ + lp.Assignment( + f"out_{typename}", + lp.TypeCast(out_type, var(f"in_{typename}")) + + + lp.TypeCast(out_type, other) + ) + for typename in arith_dtypes + ], + [ + lp.GlobalArg(f"in_{typename}", dtype=np.dtype(typename), shape=()) + for typename in arith_dtypes + ] + [...] + ) + + ctx = ctx_factory() + code = lp.generate_code_v2(knl).device_code() + # just testing here that the generated code builds + cl.Program(ctx, code).build() + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab