From 838730666a5badb705ba979cadae7df6fda0bc35 Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Wed, 19 Mar 2025 14:42:55 -0700
Subject: [PATCH] _get_f_placeholder_args: set ForceValueArgTag (#304)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* _get_f_placeholder_args: set ForceValueArgTag

* Update requirements.txt

* skip scalar arg handling in _args_to_device_buffers

* Revert changes to requirements.txt

* add a simple test

* Fix test

---------

Co-authored-by: Andreas Klöckner <inform@tiker.net>
---
 arraycontext/impl/pytato/compile.py |  9 +++++----
 test/test_pytato_arraycontext.py    | 27 +++++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index e77c109..90449f0 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -218,8 +218,10 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx):
     :attr:`BaseLazilyCompilingFunctionCaller.f`.
     """
     if np.isscalar(arg):
+        from pytato.tags import ForceValueArgTag
         name = arg_id_to_name[kw,]
-        return pt.make_placeholder(name, (), np.dtype(type(arg)))
+        return pt.make_placeholder(name, (), np.dtype(type(arg)),
+                                   tags=frozenset({ForceValueArgTag()}))
     elif isinstance(arg, pt.Array):
         name = arg_id_to_name[kw,]
         # Transform the DAG to give metadata inference a chance to do its job
@@ -533,9 +535,8 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
     for arg_id, arg in arg_id_to_arg.items():
         if np.isscalar(arg):
             if isinstance(actx, PytatoPyOpenCLArrayContext):
-                import pyopencl.array as cla
-                arg = cla.to_device(actx.queue, np.array(arg),
-                        allocator=actx.allocator)
+                # Scalar kernel args are passed as lp.ValueArgs
+                pass
             elif isinstance(actx, PytatoJAXArrayContext):
                 import jax
                 arg = jax.device_put(arg)
diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py
index a405038..deee740 100644
--- a/test/test_pytato_arraycontext.py
+++ b/test/test_pytato_arraycontext.py
@@ -247,6 +247,33 @@ def test_transfer(actx_factory):
     # }}}
 
 
+def test_pass_args_compiled_func(actx_factory):
+    import numpy as np
+
+    import loopy as lp
+    import pyopencl as cl
+    import pyopencl.array
+    import pytato as pt
+
+    def twice(x, y, a):
+        return 2 * x * y * a
+
+    actx = _PytatoPyOpenCLArrayContextForTests(actx_factory().queue)
+
+    dev_scalar = pt.make_data_wrapper(cl.array.to_device(actx.queue, np.float64(23)))
+
+    f = actx.compile(twice)
+
+    assert actx.to_numpy(f(99.0, np.float64(2.0), dev_scalar)) == 2*23*99*2
+
+    compiled_func, = f.program_cache.values()
+    ep = compiled_func.pytato_program.program.t_unit.default_entrypoint
+
+    assert isinstance(ep.arg_dict["_actx_in_0"], lp.ValueArg)
+    assert isinstance(ep.arg_dict["_actx_in_1"], lp.ValueArg)
+    assert isinstance(ep.arg_dict["_actx_in_2"], lp.ArrayArg)
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
-- 
GitLab