diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 36dc4b4fb086d08247fc2eae86e21f79b5b117dd..781947fa5176a7df899bce586e07677e72728d86 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -110,7 +110,7 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...] if np.isscalar(arg): arg_id = (iarg,) arg_id_to_arg[arg_id] = arg - arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(arg)) + arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(type(arg))) elif is_array_container(arg): def id_collector(keys, ary): arg_id = (iarg,) + keys @@ -136,7 +136,7 @@ def _get_f_placeholder_args(arg, iarg, arg_id_to_name): """ if np.isscalar(arg): name = arg_id_to_name[(iarg,)] - return pt.make_placeholder(name, (), np.dtype(arg)) + return pt.make_placeholder(name, (), np.dtype(type(arg))) elif is_array_container(arg): def _rec_to_placeholder(keys, ary): name = arg_id_to_name[(iarg,) + keys]