From e18443840c38db4c64a6a4335f5ee115c2a08ec6 Mon Sep 17 00:00:00 2001
From: jdsteve2 <jdsteve2@illinois.edu>
Date: Sat, 22 May 2021 18:59:31 -0500
Subject: [PATCH] Move splitting of comma-separated var string from
 add_and_infer_dtypes into add_dtypes so that it can also handle
 comma-separated strings

---
 loopy/kernel/tools.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py
index e16cb59f3..cac814858 100644
--- a/loopy/kernel/tools.py
+++ b/loopy/kernel/tools.py
@@ -45,7 +45,16 @@ def add_dtypes(kernel, dtype_dict):
     :arg dtype_dict: a mapping from variable names to :class:`numpy.dtype`
         instances
     """
-    dtype_dict_remainder, new_args, new_temp_vars = _add_dtypes(kernel, dtype_dict)
+    processed_dtype_dict = {}
+
+    for k, v in dtype_dict.items():
+        for subkey in k.split(","):
+            subkey = subkey.strip()
+            if subkey:
+                processed_dtype_dict[subkey] = v
+
+    dtype_dict_remainder, new_args, new_temp_vars = _add_dtypes(
+        kernel, processed_dtype_dict)
 
     if dtype_dict_remainder:
         raise RuntimeError("unused argument dtypes: %s"
@@ -104,15 +113,7 @@ def get_arguments_with_incomplete_dtype(kernel):
 
 
 def add_and_infer_dtypes(kernel, dtype_dict, expect_completion=False):
-    processed_dtype_dict = {}
-
-    for k, v in dtype_dict.items():
-        for subkey in k.split(","):
-            subkey = subkey.strip()
-            if subkey:
-                processed_dtype_dict[subkey] = v
-
-    kernel = add_dtypes(kernel, processed_dtype_dict)
+    kernel = add_dtypes(kernel, dtype_dict)
 
     from loopy.type_inference import infer_unknown_types
     return infer_unknown_types(kernel, expect_completion=expect_completion)
-- 
GitLab