From 5d5c660dee4518c8395c1bdde7e3c5b7b1d34f18 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Mon, 19 Jul 2021 16:12:39 -0500
Subject: [PATCH] improves scalar's dtype guess

---
 arraycontext/impl/pytato/compile.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index 36dc4b4..781947f 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -110,7 +110,7 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...]
         if np.isscalar(arg):
             arg_id = (iarg,)
             arg_id_to_arg[arg_id] = arg
-            arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(arg))
+            arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(type(arg)))
         elif is_array_container(arg):
             def id_collector(keys, ary):
                 arg_id = (iarg,) + keys
@@ -136,7 +136,7 @@ def _get_f_placeholder_args(arg, iarg, arg_id_to_name):
     """
     if np.isscalar(arg):
         name = arg_id_to_name[(iarg,)]
-        return pt.make_placeholder(name, (), np.dtype(arg))
+        return pt.make_placeholder(name, (), np.dtype(type(arg)))
     elif is_array_container(arg):
         def _rec_to_placeholder(keys, ary):
             name = arg_id_to_name[(iarg,) + keys]
-- 
GitLab