From e18443840c38db4c64a6a4335f5ee115c2a08ec6 Mon Sep 17 00:00:00 2001 From: jdsteve2 <jdsteve2@illinois.edu> Date: Sat, 22 May 2021 18:59:31 -0500 Subject: [PATCH] Move splitting of comma-separated var string from add_and_infer_dtypes into add_dtypes so that it can also handle comma-separated strings --- loopy/kernel/tools.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index e16cb59f3..cac814858 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -45,7 +45,16 @@ def add_dtypes(kernel, dtype_dict): :arg dtype_dict: a mapping from variable names to :class:`numpy.dtype` instances """ - dtype_dict_remainder, new_args, new_temp_vars = _add_dtypes(kernel, dtype_dict) + processed_dtype_dict = {} + + for k, v in dtype_dict.items(): + for subkey in k.split(","): + subkey = subkey.strip() + if subkey: + processed_dtype_dict[subkey] = v + + dtype_dict_remainder, new_args, new_temp_vars = _add_dtypes( + kernel, processed_dtype_dict) if dtype_dict_remainder: raise RuntimeError("unused argument dtypes: %s" @@ -104,15 +113,7 @@ def get_arguments_with_incomplete_dtype(kernel): def add_and_infer_dtypes(kernel, dtype_dict, expect_completion=False): - processed_dtype_dict = {} - - for k, v in dtype_dict.items(): - for subkey in k.split(","): - subkey = subkey.strip() - if subkey: - processed_dtype_dict[subkey] = v - - kernel = add_dtypes(kernel, processed_dtype_dict) + kernel = add_dtypes(kernel, dtype_dict) from loopy.type_inference import infer_unknown_types return infer_unknown_types(kernel, expect_completion=expect_completion) -- GitLab