From dc8109021d1c76016457fe91cfa203bd7e2f3cea Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Mon, 19 Sep 2022 21:52:07 -0500
Subject: [PATCH] pytatoactx.compile: support multi-dimensional obj arrays

---
 arraycontext/impl/pytato/compile.py |  8 ++++----
 test/test_arraycontext.py           | 25 +++++++++++++++++++++++++
 2 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index ac4c01d..20455cd 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -354,8 +354,7 @@ class BaseLazilyCompilingFunctionCaller:
                 f" but an instance of '{output_template.__class__}' instead.")
 
         def _as_dict_of_named_arrays(keys, ary):
-            name = "_pt_out_" + "_".join(str(key)
-                                         for key in keys)
+            name = "_pt_out_" + _ary_container_key_stringifier(keys)
             output_id_to_name_in_program[keys] = name
             dict_of_named_arrays[name] = ary
             return ary
@@ -606,7 +605,7 @@ class CompiledFunction(abc.ABC):
 # }}}
 
 
-# {{{ copmiled pyopencl function
+# {{{ compiled pyopencl function
 
 @dataclass(frozen=True)
 class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction):
@@ -698,7 +697,8 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction):
 # }}}
 
 
-# {{{ comiled jax function
+# {{{ compiled jax function
+
 @dataclass(frozen=True)
 class CompiledJAXFunctionReturningArrayContainer(CompiledFunction):
     """
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 842d108..d27e3db 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -1278,6 +1278,31 @@ def test_actx_compile_kwargs(actx_factory):
     np.testing.assert_allclose(result.u, -3.14*v_y)
     np.testing.assert_allclose(result.v, 3.14*v_x)
 
+
+def test_actx_compile_with_tuple_output_keys(actx_factory):
+    # arraycontext.git<=3c9aee68 would fail due to a bug in output
+    # key stringification logic.
+    from arraycontext import (to_numpy, from_numpy)
+    actx = actx_factory()
+
+    def my_rhs(scale, vel):
+        result = np.empty((1, 1), dtype=object)
+        result[0, 0] = scale_and_orthogonalize(scale, vel)
+        return result
+
+    compiled_rhs = actx.compile(my_rhs)
+
+    v_x = np.random.rand(10)
+    v_y = np.random.rand(10)
+
+    vel = from_numpy(Velocity2D(v_x, v_y, actx), actx)
+
+    scaled_speed = compiled_rhs(3.14, vel=vel)
+
+    result = to_numpy(scaled_speed, actx)[0, 0]
+    np.testing.assert_allclose(result.u, -3.14*v_y)
+    np.testing.assert_allclose(result.v, 3.14*v_x)
+
 # }}}
 
 
-- 
GitLab