From 5d5c660dee4518c8395c1bdde7e3c5b7b1d34f18 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Mon, 19 Jul 2021 16:12:39 -0500 Subject: [PATCH] improves scalar's dtype guess --- arraycontext/impl/pytato/compile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 36dc4b4..781947f 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -110,7 +110,7 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...] if np.isscalar(arg): arg_id = (iarg,) arg_id_to_arg[arg_id] = arg - arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(arg)) + arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(type(arg))) elif is_array_container(arg): def id_collector(keys, ary): arg_id = (iarg,) + keys @@ -136,7 +136,7 @@ def _get_f_placeholder_args(arg, iarg, arg_id_to_name): """ if np.isscalar(arg): name = arg_id_to_name[(iarg,)] - return pt.make_placeholder(name, (), np.dtype(arg)) + return pt.make_placeholder(name, (), np.dtype(type(arg))) elif is_array_container(arg): def _rec_to_placeholder(keys, ary): name = arg_id_to_name[(iarg,) + keys] -- GitLab