Skip to content
Snippets Groups Projects
Unverified Commit 83873066 authored by Matthias Diener's avatar Matthias Diener Committed by GitHub
Browse files

_get_f_placeholder_args: set ForceValueArgTag (#304)


* _get_f_placeholder_args: set ForceValueArgTag

* Update requirements.txt

* skip scalar arg handling in _args_to_device_buffers

* Revert changes to requirements.txt

* add a simple test

* Fix test

---------

Co-authored-by: default avatarAndreas Klöckner <inform@tiker.net>
parent 029026cc
No related branches found
No related tags found
No related merge requests found
Pipeline #653658 failed
...@@ -218,8 +218,10 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): ...@@ -218,8 +218,10 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx):
:attr:`BaseLazilyCompilingFunctionCaller.f`. :attr:`BaseLazilyCompilingFunctionCaller.f`.
""" """
if np.isscalar(arg): if np.isscalar(arg):
from pytato.tags import ForceValueArgTag
name = arg_id_to_name[kw,] name = arg_id_to_name[kw,]
return pt.make_placeholder(name, (), np.dtype(type(arg))) return pt.make_placeholder(name, (), np.dtype(type(arg)),
tags=frozenset({ForceValueArgTag()}))
elif isinstance(arg, pt.Array): elif isinstance(arg, pt.Array):
name = arg_id_to_name[kw,] name = arg_id_to_name[kw,]
# Transform the DAG to give metadata inference a chance to do its job # Transform the DAG to give metadata inference a chance to do its job
...@@ -533,9 +535,8 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): ...@@ -533,9 +535,8 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
for arg_id, arg in arg_id_to_arg.items(): for arg_id, arg in arg_id_to_arg.items():
if np.isscalar(arg): if np.isscalar(arg):
if isinstance(actx, PytatoPyOpenCLArrayContext): if isinstance(actx, PytatoPyOpenCLArrayContext):
import pyopencl.array as cla # Scalar kernel args are passed as lp.ValueArgs
arg = cla.to_device(actx.queue, np.array(arg), pass
allocator=actx.allocator)
elif isinstance(actx, PytatoJAXArrayContext): elif isinstance(actx, PytatoJAXArrayContext):
import jax import jax
arg = jax.device_put(arg) arg = jax.device_put(arg)
......
...@@ -247,6 +247,33 @@ def test_transfer(actx_factory): ...@@ -247,6 +247,33 @@ def test_transfer(actx_factory):
# }}} # }}}
def test_pass_args_compiled_func(actx_factory):
import numpy as np
import loopy as lp
import pyopencl as cl
import pyopencl.array
import pytato as pt
def twice(x, y, a):
return 2 * x * y * a
actx = _PytatoPyOpenCLArrayContextForTests(actx_factory().queue)
dev_scalar = pt.make_data_wrapper(cl.array.to_device(actx.queue, np.float64(23)))
f = actx.compile(twice)
assert actx.to_numpy(f(99.0, np.float64(2.0), dev_scalar)) == 2*23*99*2
compiled_func, = f.program_cache.values()
ep = compiled_func.pytato_program.program.t_unit.default_entrypoint
assert isinstance(ep.arg_dict["_actx_in_0"], lp.ValueArg)
assert isinstance(ep.arg_dict["_actx_in_1"], lp.ValueArg)
assert isinstance(ep.arg_dict["_actx_in_2"], lp.ArrayArg)
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
if len(sys.argv) > 1: if len(sys.argv) > 1:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment