diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index e16cb59f3d3d2e3f6b116f8d13a0ea93e5268f65..cac814858324b6d1908bab611cbeac47ca6b7d19 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -45,7 +45,16 @@ def add_dtypes(kernel, dtype_dict): :arg dtype_dict: a mapping from variable names to :class:`numpy.dtype` instances """ - dtype_dict_remainder, new_args, new_temp_vars = _add_dtypes(kernel, dtype_dict) + processed_dtype_dict = {} + + for k, v in dtype_dict.items(): + for subkey in k.split(","): + subkey = subkey.strip() + if subkey: + processed_dtype_dict[subkey] = v + + dtype_dict_remainder, new_args, new_temp_vars = _add_dtypes( + kernel, processed_dtype_dict) if dtype_dict_remainder: raise RuntimeError("unused argument dtypes: %s" @@ -104,15 +113,7 @@ def get_arguments_with_incomplete_dtype(kernel): def add_and_infer_dtypes(kernel, dtype_dict, expect_completion=False): - processed_dtype_dict = {} - - for k, v in dtype_dict.items(): - for subkey in k.split(","): - subkey = subkey.strip() - if subkey: - processed_dtype_dict[subkey] = v - - kernel = add_dtypes(kernel, processed_dtype_dict) + kernel = add_dtypes(kernel, dtype_dict) from loopy.type_inference import infer_unknown_types return infer_unknown_types(kernel, expect_completion=expect_completion)