diff --git a/loopy/type_inference.py b/loopy/type_inference.py index 3fb165ead5634d1efbf899a5a05ef9be4355d672..7b3a67c6b0b11a3adc68d58a10f309a6ee21919e 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -671,6 +671,11 @@ def infer_arg_and_reduction_dtypes_for_reduction_expression( raise LoopyError("failed to determine type of accumulator for " "reduction '%s'" % expr) + arg_dtypes = tuple( + dt.with_target(kernel.target) + if dt is not lp.auto else dt + for dt in arg_dtypes) + reduction_dtypes = expr.operation.result_dtypes(kernel, *arg_dtypes) reduction_dtypes = tuple( dt.with_target(kernel.target)