diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py
index a49a81e70237c231c2644c24f845442075e17b33..a7cea85dd0c18454b24d4df7a310c9ad9a7f8c84 100644
--- a/arraycontext/impl/pytato.py
+++ b/arraycontext/impl/pytato.py
@@ -152,71 +152,63 @@ class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
 
 
 class PytatoCompiledOperator:
-    def __init__(self, actx, pytato_program, input_spec, output_spec):
+    def __init__(self, actx, pytato_program, input_id_to_name_in_program,
+                 output_id_to_name_in_program, output_template):
         self.actx = actx
         self.pytato_program = pytato_program
-        self.input_spec = input_spec
-        self.output_spec = output_spec
+        self.input_id_to_name_in_program = input_id_to_name_in_program
+        self.output_id_to_name_in_program = output_id_to_name_in_program
+        self.output_template = output_template
 
     def __call__(self, *args):
         import pytato as pt
         import pyopencl.array as cla
-        from arraycontext.impl import _is_meshmode_dofarray
-        from pytools.obj_array import flat_obj_array
-
-        updated_kwargs = {}
-
-        def from_obj_array_to_input_dict(array, pos):
-            input_dict = {}
-            for i in range(len(self.input_spec[pos])):
-                for j in range(self.input_spec[pos][i]):
-                    ary = array[i][j]
-                    arg_name = f"_msh_inp_{pos}_{i}_{j}"
-                    if arg_name not in (
-                            self.pytato_program.program["_pt_kernel"].arg_dict):
-                        continue
+        from arraycontext import (is_array_container,
+                                  rec_keyed_map_array_container)
+
+        input_kwargs_to_loopy = {}
+
+        # {{{ extract loopy arguments, execute the program
+
+        for pos, arg in enumerate(args):
+            if isinstance(arg, np.number):
+                input_kwargs_to_loopy[self.input_id_to_name_in_program[pos]] = (
+                    arg)
+            elif is_array_container(arg):
+                def _extract_lpy_kwargs(keys, ary):
                     if isinstance(ary, pt.array.DataWrapper):
-                        input_dict[arg_name] = ary.data
+                        processed_ary = ary.data
                     elif isinstance(ary, cla.Array):
-                        input_dict[arg_name] = ary
+                        processed_ary = ary
                     elif isinstance(ary, pt.Array):
-                        input_dict[arg_name] = self.actx.freeze(
-                                ary).with_queue(self.actx.queue)
+                        processed_ary = (self.actx.freeze(ary)
+                                         .with_queue(self.actx.queue))
                     else:
-                        raise TypeError("Expect pt.DataWrapper or CL-array, got "
+                        raise TypeError("Expect pt.Array or CL-array, got "
                                 f"{type(ary)}")
 
-            return input_dict
+                    input_kwargs_to_loopy[
+                        self.input_id_to_name_in_program[(pos,)
+                                                         + keys]] = processed_ary
+                    return ary
 
-        def from_return_dict_to_obj_array(return_dict):
-            from meshmode.dof_array import DOFArray  # pylint: disable=import-error
-            return flat_obj_array([DOFArray.from_list(self.actx,
-                [self.actx.thaw(return_dict[f"_msh_out_{i}_{j}"])
-                 for j in range(self.output_spec[i])])
-                for i in range(len(self.output_spec))])
-
-        for iarg, arg in enumerate(args):
-            if isinstance(arg, np.number):
-                arg_name = f"_msh_inp_{iarg}"
-                if arg_name not in (
-                        self.pytato_program.program["_pt_kernel"].arg_dict):
-                    continue
-
-                updated_kwargs[arg_name] = cla.to_device(self.actx.queue,
-                        np.array(arg))
-            elif isinstance(arg, np.ndarray) and all(_is_meshmode_dofarray(el)
-                                                     for el in arg):
-                updated_kwargs.update(from_obj_array_to_input_dict(arg, iarg))
+                rec_keyed_map_array_container(_extract_lpy_kwargs, arg)
             else:
-                raise NotImplementedError("PytatoCompiledOperator cannot handle"
-                                          f" '{type(arg)}'s")
+                raise NotImplementedError(type(arg))
 
         evt, out_dict = self.pytato_program(queue=self.actx.queue,
                                             allocator=self.actx.allocator,
-                                            **updated_kwargs)
+                                            **input_kwargs_to_loopy)
+
         evt.wait()
 
-        return from_return_dict_to_obj_array(out_dict)
+        # }}}
+
+        def to_output_template(keys, _):
+            return self.actx.thaw(out_dict[self.output_id_to_name_in_program[keys]])
+
+        return rec_keyed_map_array_container(to_output_template,
+                                             self.output_template)
 
 
 class PytatoArrayContext(ArrayContext):
@@ -308,70 +300,66 @@ class PytatoArrayContext(ArrayContext):
     # }}}
 
     def compile(self, f: Callable[[Any], Any],
-            inputs_like: Tuple[Union[Number, np.ndarray], ...]) -> Callable[
-                ..., Any]:
-        from pytools.obj_array import flat_obj_array
-        from arraycontext.impl import _is_meshmode_dofarray
-        from meshmode.dof_array import DOFArray  # pylint: disable=import-error
+                inputs_like: Tuple[Union[Number, np.ndarray], ...]
+                ) -> Callable[..., Any]:
+        from arraycontext import (rec_keyed_map_array_container,
+                                  is_array_container)
         import pytato as pt
 
-        def make_placeholder_like(input_like, pos):
+        dict_of_named_arrays = {}
+        output_naming_map = {}
+        input_naming_map = {}
+
+        def to_placeholder(input_like, pos):
             if isinstance(input_like, np.number):
-                return pt.make_placeholder((), input_like.dtype,
-                                           f"_msh_inp_{pos}")
-            elif isinstance(input_like, np.ndarray) and all(_is_meshmode_dofarray(e)
-                                                            for e in input_like):
-                return flat_obj_array([DOFArray.from_list(self,
-                    [pt.make_placeholder(grp_ary.shape,
-                                         grp_ary.dtype, f"_msh_inp_{pos}_{i}_{j}")
-                     for j, grp_ary in enumerate(dof_ary)])
-                    for i, dof_ary in enumerate(input_like)])
-
-            raise NotImplementedError(f"Unknown input type '{type(input_like)}'.")
-
-        def as_dict_of_named_arrays(fields_obj_ary):
-            dict_of_named_arrays = {}
-            # output_spec: a list of length #fields; ith-entry denotes #groups in
-            # ith-field
-            output_spec = []
-            for i, field in enumerate(fields_obj_ary):
-                output_spec.append(len(field))
-                for j, grp in enumerate(field):
-                    dict_of_named_arrays[f"_msh_out_{i}_{j}"] = grp
-
-            return pt.make_dict_of_named_arrays(dict_of_named_arrays), output_spec
-
-        outputs = f(*[make_placeholder_like(el, iel)
+                name = f"_pt_in_{pos}"
+                input_naming_map[(pos, )] = name
+                return pt.make_placeholder((), input_like.dtype, name)
+            elif is_array_container(input_like):
+                def _rec_to_placeholder(keys, ary):
+                    name = f"_pt_in_{pos}_" + "_".join(str(key)
+                                                       for key in keys)
+                    input_naming_map[(pos,) + keys] = name
+                    return pt.make_placeholder(ary.shape, ary.dtype,
+                                               name)
+                return rec_keyed_map_array_container(_rec_to_placeholder,
+                                                     input_like)
+            else:
+                raise NotImplementedError("Unknown input type "
+                                          f"'{type(input_like)}'.")
+
+        outputs = f(*[to_placeholder(el, iel)
                       for iel, el in enumerate(inputs_like)])
 
-        if not (isinstance(outputs, np.ndarray)
-                and all(_is_meshmode_dofarray(e)
-                        for e in outputs)):
-            raise TypeError("Can only pass in functions that return numpy"
-                            " array of DOFArrays.")
+        if not is_array_container(outputs):
+            # TODO: We could possibly just short-circuit this interface if the
+            # returned type is a scalar. Not sure if it's worth it though.
+            raise ValueError("Function to be compiled did not return an array"
+                             " container.")
+
+        def _as_dict_of_named_arrays(keys, ary):
+            name = "_pt_out_" + "_".join(str(key)
+                                          for key in keys)
+            output_naming_map[keys] = name
+            dict_of_named_arrays[name] = ary
+            return ary
 
-        output_dict_of_named_arrays, output_spec = as_dict_of_named_arrays(outputs)
+        rec_keyed_map_array_container(_as_dict_of_named_arrays,
+                                      outputs)
 
-        pytato_program = pt.generate_loopy(output_dict_of_named_arrays,
+        pytato_program = pt.generate_loopy(dict_of_named_arrays,
                                            options={"return_dict": True},
                                            cl_device=self.queue.device)
 
         if False:
-            from time import time
-            start = time()
             # transforming leads to compile-time slow downs (turning off for now)
-            pytato_program.program = self.transform_loopy_program(
-                    pytato_program.program)
-            end = time()
-            print(f"Transforming took {end-start} secs")
+            pytato_program.program = self.transform_loopy_program(pytato_program
+                                                                  .program)
 
         return PytatoCompiledOperator(self, pytato_program,
-                                      [[len(arg) for arg in input_like]
-                                       if isinstance(input_like, np.ndarray)
-                                       else []
-
-                                       for input_like in inputs_like],
-                                      output_spec)
+                                      input_naming_map,
+                                      output_naming_map,
+                                      output_template=outputs)
 
     def transform_loopy_program(self, prg):
         from loopy.translation_unit import for_each_kernel