From 838730666a5badb705ba979cadae7df6fda0bc35 Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Wed, 19 Mar 2025 14:42:55 -0700 Subject: [PATCH] _get_f_placeholder_args: set ForceValueArgTag (#304) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * _get_f_placeholder_args: set ForceValueArgTag * Update requirements.txt * skip scalar arg handling in _args_to_device_buffers * Revert changes to requirements.txt * add a simple test * Fix test --------- Co-authored-by: Andreas Klöckner <inform@tiker.net> --- arraycontext/impl/pytato/compile.py | 9 +++++---- test/test_pytato_arraycontext.py | 27 +++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index e77c109..90449f0 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -218,8 +218,10 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): :attr:`BaseLazilyCompilingFunctionCaller.f`. """ if np.isscalar(arg): + from pytato.tags import ForceValueArgTag name = arg_id_to_name[kw,] - return pt.make_placeholder(name, (), np.dtype(type(arg))) + return pt.make_placeholder(name, (), np.dtype(type(arg)), + tags=frozenset({ForceValueArgTag()})) elif isinstance(arg, pt.Array): name = arg_id_to_name[kw,] # Transform the DAG to give metadata inference a chance to do its job @@ -533,9 +535,8 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): for arg_id, arg in arg_id_to_arg.items(): if np.isscalar(arg): if isinstance(actx, PytatoPyOpenCLArrayContext): - import pyopencl.array as cla - arg = cla.to_device(actx.queue, np.array(arg), - allocator=actx.allocator) + # Scalar kernel args are passed as lp.ValueArgs + pass elif isinstance(actx, PytatoJAXArrayContext): import jax arg = jax.device_put(arg) diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py index a405038..deee740 100644 --- a/test/test_pytato_arraycontext.py +++ b/test/test_pytato_arraycontext.py @@ -247,6 +247,33 @@ def test_transfer(actx_factory): # }}} +def test_pass_args_compiled_func(actx_factory): + import numpy as np + + import loopy as lp + import pyopencl as cl + import pyopencl.array + import pytato as pt + + def twice(x, y, a): + return 2 * x * y * a + + actx = _PytatoPyOpenCLArrayContextForTests(actx_factory().queue) + + dev_scalar = pt.make_data_wrapper(cl.array.to_device(actx.queue, np.float64(23))) + + f = actx.compile(twice) + + assert actx.to_numpy(f(99.0, np.float64(2.0), dev_scalar)) == 2*23*99*2 + + compiled_func, = f.program_cache.values() + ep = compiled_func.pytato_program.program.t_unit.default_entrypoint + + assert isinstance(ep.arg_dict["_actx_in_0"], lp.ValueArg) + assert isinstance(ep.arg_dict["_actx_in_1"], lp.ValueArg) + assert isinstance(ep.arg_dict["_actx_in_2"], lp.ArrayArg) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: -- GitLab