diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index e77c10913488f531fa45d31812740d6884d2e0bf..90449f0a6136294ee21d4b8f96890110343be2b8 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -218,8 +218,10 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx):
     :attr:`BaseLazilyCompilingFunctionCaller.f`.
     """
     if np.isscalar(arg):
+        from pytato.tags import ForceValueArgTag
         name = arg_id_to_name[kw,]
-        return pt.make_placeholder(name, (), np.dtype(type(arg)))
+        return pt.make_placeholder(name, (), np.dtype(type(arg)),
+                                   tags=frozenset({ForceValueArgTag()}))
     elif isinstance(arg, pt.Array):
         name = arg_id_to_name[kw,]
         # Transform the DAG to give metadata inference a chance to do its job
@@ -533,9 +535,8 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
     for arg_id, arg in arg_id_to_arg.items():
         if np.isscalar(arg):
             if isinstance(actx, PytatoPyOpenCLArrayContext):
-                import pyopencl.array as cla
-                arg = cla.to_device(actx.queue, np.array(arg),
-                        allocator=actx.allocator)
+                # Scalar kernel args are passed as lp.ValueArgs
+                pass
             elif isinstance(actx, PytatoJAXArrayContext):
                 import jax
                 arg = jax.device_put(arg)
diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py
index a4050380688f68b7e9c406fe55c92d0cc88631c2..deee7405ec3542b04637dbb98cf4bbe1f213e23e 100644
--- a/test/test_pytato_arraycontext.py
+++ b/test/test_pytato_arraycontext.py
@@ -247,6 +247,33 @@ def test_transfer(actx_factory):
     # }}}
 
 
+def test_pass_args_compiled_func(actx_factory):
+    import numpy as np
+
+    import loopy as lp
+    import pyopencl as cl
+    import pyopencl.array
+    import pytato as pt
+
+    def twice(x, y, a):
+        return 2 * x * y * a
+
+    actx = _PytatoPyOpenCLArrayContextForTests(actx_factory().queue)
+
+    dev_scalar = pt.make_data_wrapper(cl.array.to_device(actx.queue, np.float64(23)))
+
+    f = actx.compile(twice)
+
+    assert actx.to_numpy(f(99.0, np.float64(2.0), dev_scalar)) == 2*23*99*2
+
+    compiled_func, = f.program_cache.values()
+    ep = compiled_func.pytato_program.program.t_unit.default_entrypoint
+
+    assert isinstance(ep.arg_dict["_actx_in_0"], lp.ValueArg)
+    assert isinstance(ep.arg_dict["_actx_in_1"], lp.ValueArg)
+    assert isinstance(ep.arg_dict["_actx_in_2"], lp.ArrayArg)
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1: