From dc8109021d1c76016457fe91cfa203bd7e2f3cea Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Mon, 19 Sep 2022 21:52:07 -0500 Subject: [PATCH] pytatoactx.compile: support multi-dimensional obj arrays --- arraycontext/impl/pytato/compile.py | 8 ++++---- test/test_arraycontext.py | 25 +++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index ac4c01d..20455cd 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -354,8 +354,7 @@ class BaseLazilyCompilingFunctionCaller: f" but an instance of '{output_template.__class__}' instead.") def _as_dict_of_named_arrays(keys, ary): - name = "_pt_out_" + "_".join(str(key) - for key in keys) + name = "_pt_out_" + _ary_container_key_stringifier(keys) output_id_to_name_in_program[keys] = name dict_of_named_arrays[name] = ary return ary @@ -606,7 +605,7 @@ class CompiledFunction(abc.ABC): # }}} -# {{{ copmiled pyopencl function +# {{{ compiled pyopencl function @dataclass(frozen=True) class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction): @@ -698,7 +697,8 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction): # }}} -# {{{ comiled jax function +# {{{ compiled jax function + @dataclass(frozen=True) class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): """ diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 842d108..d27e3db 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1278,6 +1278,31 @@ def test_actx_compile_kwargs(actx_factory): np.testing.assert_allclose(result.u, -3.14*v_y) np.testing.assert_allclose(result.v, 3.14*v_x) + +def test_actx_compile_with_tuple_output_keys(actx_factory): + # arraycontext.git<=3c9aee68 would fail due to a bug in output + # key stringification logic. + from arraycontext import (to_numpy, from_numpy) + actx = actx_factory() + + def my_rhs(scale, vel): + result = np.empty((1, 1), dtype=object) + result[0, 0] = scale_and_orthogonalize(scale, vel) + return result + + compiled_rhs = actx.compile(my_rhs) + + v_x = np.random.rand(10) + v_y = np.random.rand(10) + + vel = from_numpy(Velocity2D(v_x, v_y, actx), actx) + + scaled_speed = compiled_rhs(3.14, vel=vel) + + result = to_numpy(scaled_speed, actx)[0, 0] + np.testing.assert_allclose(result.u, -3.14*v_y) + np.testing.assert_allclose(result.v, 3.14*v_x) + # }}} -- GitLab