From 4341b9d63df0eb3ce8c5cb314688a72de5b01236 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Wed, 28 Jul 2021 18:58:08 -0500
Subject: [PATCH] PytatoPyOpenCLArrayContext.compile: add support for keyword
 arguments

---
 arraycontext/impl/pytato/compile.py | 27 +++++++++++++++++----------
 test/test_arraycontext.py           | 18 ++++++++++++++++++
 2 files changed, 35 insertions(+), 10 deletions(-)

diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index 781947f..2c1026c 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -39,6 +39,7 @@ from pyrsistent import pmap, PMap
 
 import pyopencl.array as cla
 import pytato as pt
+import itertools
 
 
 # {{{ helper classes: AbstractInputDescriptor
@@ -90,7 +91,8 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
     return "_".join(_rec_str(key) for key in keys)
 
 
-def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...]
+def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...],
+                                           kwargs: Mapping[str, Any]
                                            ) -> "Tuple[PMap[Tuple[Any, ...],\
                                                             Any],\
                                                        PMap[Tuple[Any, ...],\
@@ -106,14 +108,15 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...]
     arg_id_to_arg: Dict[Tuple[Any, ...], Any] = {}
     arg_id_to_descr: Dict[Tuple[Any, ...], AbstractInputDescriptor] = {}
 
-    for iarg, arg in enumerate(args):
+    for kw, arg in itertools.chain(enumerate(args),
+                                   kwargs.items()):
         if np.isscalar(arg):
-            arg_id = (iarg,)
+            arg_id = (kw,)
             arg_id_to_arg[arg_id] = arg
             arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(type(arg)))
         elif is_array_container(arg):
             def id_collector(keys, ary):
-                arg_id = (iarg,) + keys
+                arg_id = (kw,) + keys
                 arg_id_to_arg[arg_id] = ary
                 arg_id_to_descr[arg_id] = LeafArrayDescriptor(np.dtype(ary.dtype),
                                                               ary.shape)
@@ -128,18 +131,18 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...]
     return pmap(arg_id_to_arg), pmap(arg_id_to_descr)
 
 
-def _get_f_placeholder_args(arg, iarg, arg_id_to_name):
+def _get_f_placeholder_args(arg, kw, arg_id_to_name):
     """
     Helper for :class:`LazilyCompilingFunctionCaller.__call__`. Returns the
     placeholder version of an argument to
     :attr:`LazilyCompilingFunctionCaller.f`.
     """
     if np.isscalar(arg):
-        name = arg_id_to_name[(iarg,)]
+        name = arg_id_to_name[(kw,)]
         return pt.make_placeholder(name, (), np.dtype(type(arg)))
     elif is_array_container(arg):
         def _rec_to_placeholder(keys, ary):
-            name = arg_id_to_name[(iarg,) + keys]
+            name = arg_id_to_name[(kw,) + keys]
             return pt.make_placeholder(name, ary.shape, ary.dtype)
         return rec_keyed_map_array_container(_rec_to_placeholder,
                                                 arg)
@@ -167,7 +170,7 @@ class LazilyCompilingFunctionCaller:
     program_cache: Dict["PMap[Tuple[Any, ...], AbstractInputDescriptor]",
                         "CompiledFunction"] = field(default_factory=lambda: {})
 
-    def __call__(self, *args: Any) -> Any:
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
         """
         Returns the result of :attr:`~LazilyCompilingFunctionCaller.f`'s
         function application on *args*.
@@ -178,7 +181,9 @@ class LazilyCompilingFunctionCaller:
         The intermediary pytato DAG for *args* is memoized in *self*.
         """
         from pytato.target.loopy import BoundPyOpenCLProgram
-        arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(args)
+
+        arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(
+            args, kwargs)
 
         try:
             compiled_f = self.program_cache[arg_id_to_descr]
@@ -198,7 +203,9 @@ class LazilyCompilingFunctionCaller:
             for arg_id in arg_id_to_arg}
 
         outputs = self.f(*[_get_f_placeholder_args(arg, iarg, input_naming_map)
-                           for iarg, arg in enumerate(args)])
+                           for iarg, arg in enumerate(args)],
+                         **{kw: _get_f_placeholder_args(arg, kw, input_naming_map)
+                            for kw, arg in kwargs.items()})
 
         if not is_array_container(outputs):
             # TODO: We could possibly just short-circuit this interface if the
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 668e320..9e855b0 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -937,6 +937,24 @@ def test_actx_compile_python_scalar(actx_factory):
     np.testing.assert_allclose(result.u, -3.14*v_y)
     np.testing.assert_allclose(result.v, 3.14*v_x)
 
+
+def test_actx_compile_kwargs(actx_factory):
+    from arraycontext import (to_numpy, from_numpy)
+    actx = actx_factory()
+
+    compiled_rhs = actx.compile(scale_and_orthogonalize)
+
+    v_x = np.random.rand(10)
+    v_y = np.random.rand(10)
+
+    vel = from_numpy(Velocity2D(v_x, v_y, actx), actx)
+
+    scaled_speed = compiled_rhs(3.14, vel=vel)
+
+    result = to_numpy(scaled_speed, actx)
+    np.testing.assert_allclose(result.u, -3.14*v_y)
+    np.testing.assert_allclose(result.v, 3.14*v_x)
+
 # }}}
 
 
-- 
GitLab