From 4341b9d63df0eb3ce8c5cb314688a72de5b01236 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Wed, 28 Jul 2021 18:58:08 -0500 Subject: [PATCH] PytatoPyOpenCLArrayContext.compile: add support for keyword arguments --- arraycontext/impl/pytato/compile.py | 27 +++++++++++++++++---------- test/test_arraycontext.py | 18 ++++++++++++++++++ 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 781947f..2c1026c 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -39,6 +39,7 @@ from pyrsistent import pmap, PMap import pyopencl.array as cla import pytato as pt +import itertools # {{{ helper classes: AbstractInputDescriptor @@ -90,7 +91,8 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: return "_".join(_rec_str(key) for key in keys) -def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...] +def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], + kwargs: Mapping[str, Any] ) -> "Tuple[PMap[Tuple[Any, ...],\ Any],\ PMap[Tuple[Any, ...],\ @@ -106,14 +108,15 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...] arg_id_to_arg: Dict[Tuple[Any, ...], Any] = {} arg_id_to_descr: Dict[Tuple[Any, ...], AbstractInputDescriptor] = {} - for iarg, arg in enumerate(args): + for kw, arg in itertools.chain(enumerate(args), + kwargs.items()): if np.isscalar(arg): - arg_id = (iarg,) + arg_id = (kw,) arg_id_to_arg[arg_id] = arg arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(type(arg))) elif is_array_container(arg): def id_collector(keys, ary): - arg_id = (iarg,) + keys + arg_id = (kw,) + keys arg_id_to_arg[arg_id] = ary arg_id_to_descr[arg_id] = LeafArrayDescriptor(np.dtype(ary.dtype), ary.shape) @@ -128,18 +131,18 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...] return pmap(arg_id_to_arg), pmap(arg_id_to_descr) -def _get_f_placeholder_args(arg, iarg, arg_id_to_name): +def _get_f_placeholder_args(arg, kw, arg_id_to_name): """ Helper for :class:`LazilyCompilingFunctionCaller.__call__`. Returns the placeholder version of an argument to :attr:`LazilyCompilingFunctionCaller.f`. """ if np.isscalar(arg): - name = arg_id_to_name[(iarg,)] + name = arg_id_to_name[(kw,)] return pt.make_placeholder(name, (), np.dtype(type(arg))) elif is_array_container(arg): def _rec_to_placeholder(keys, ary): - name = arg_id_to_name[(iarg,) + keys] + name = arg_id_to_name[(kw,) + keys] return pt.make_placeholder(name, ary.shape, ary.dtype) return rec_keyed_map_array_container(_rec_to_placeholder, arg) @@ -167,7 +170,7 @@ class LazilyCompilingFunctionCaller: program_cache: Dict["PMap[Tuple[Any, ...], AbstractInputDescriptor]", "CompiledFunction"] = field(default_factory=lambda: {}) - def __call__(self, *args: Any) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> Any: """ Returns the result of :attr:`~LazilyCompilingFunctionCaller.f`'s function application on *args*. @@ -178,7 +181,9 @@ class LazilyCompilingFunctionCaller: The intermediary pytato DAG for *args* is memoized in *self*. """ from pytato.target.loopy import BoundPyOpenCLProgram - arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(args) + + arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( + args, kwargs) try: compiled_f = self.program_cache[arg_id_to_descr] @@ -198,7 +203,9 @@ class LazilyCompilingFunctionCaller: for arg_id in arg_id_to_arg} outputs = self.f(*[_get_f_placeholder_args(arg, iarg, input_naming_map) - for iarg, arg in enumerate(args)]) + for iarg, arg in enumerate(args)], + **{kw: _get_f_placeholder_args(arg, kw, input_naming_map) + for kw, arg in kwargs.items()}) if not is_array_container(outputs): # TODO: We could possibly just short-circuit this interface if the diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 668e320..9e855b0 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -937,6 +937,24 @@ def test_actx_compile_python_scalar(actx_factory): np.testing.assert_allclose(result.u, -3.14*v_y) np.testing.assert_allclose(result.v, 3.14*v_x) + +def test_actx_compile_kwargs(actx_factory): + from arraycontext import (to_numpy, from_numpy) + actx = actx_factory() + + compiled_rhs = actx.compile(scale_and_orthogonalize) + + v_x = np.random.rand(10) + v_y = np.random.rand(10) + + vel = from_numpy(Velocity2D(v_x, v_y, actx), actx) + + scaled_speed = compiled_rhs(3.14, vel=vel) + + result = to_numpy(scaled_speed, actx) + np.testing.assert_allclose(result.u, -3.14*v_y) + np.testing.assert_allclose(result.v, 3.14*v_x) + # }}} -- GitLab