diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index aa6de375076384afbc8f16e04f195327e5db824a..07c154464c1e06d475fe88a89078f83bf5328372 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -256,7 +256,8 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
 
 def rec_map_array_container(
         f: Callable[[Any], Any],
-        ary: ArrayOrContainerT) -> ArrayOrContainerT:
+        ary: ArrayOrContainerT,
+        leaf_class: Optional[type] = None) -> ArrayOrContainerT:
     r"""Applies *f* recursively to an :class:`ArrayContainer`.
 
     For a non-recursive version see :func:`map_array_container`.
@@ -264,18 +265,32 @@ def rec_map_array_container(
     :param ary: a (potentially nested) structure of :class:`ArrayContainer`\ s,
         or an instance of a base array type.
     """
-    return _map_array_container_impl(f, ary, recursive=True)
+    return _map_array_container_impl(f, ary, leaf_cls=leaf_class, recursive=True)
 
 
 def mapped_over_array_containers(
-        f: Callable[[Any], Any]) -> Callable[[ArrayOrContainerT], ArrayOrContainerT]:
+        f: Optional[Callable[[Any], Any]] = None,
+        leaf_class: Optional[type] = None) -> Union[
+            Callable[[ArrayOrContainerT], ArrayOrContainerT],
+            Callable[
+                [Callable[[Any], Any]],
+                Callable[[ArrayOrContainerT], ArrayOrContainerT]]]:
     """Decorator around :func:`rec_map_array_container`."""
-    wrapper = partial(rec_map_array_container, f)
-    update_wrapper(wrapper, f)
-    return wrapper
+    def decorator(g: Callable[[Any], Any]) -> Callable[
+            [ArrayOrContainerT], ArrayOrContainerT]:
+        wrapper = partial(rec_map_array_container, g, leaf_class=leaf_class)
+        update_wrapper(wrapper, g)
+        return wrapper
+    if f is not None:
+        return decorator(f)
+    else:
+        return decorator
 
 
-def rec_multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
+def rec_multimap_array_container(
+        f: Callable[..., Any],
+        *args: Any,
+        leaf_class: Optional[type] = None) -> Any:
     r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s.
 
     For a non-recursive version see :func:`multimap_array_container`.
@@ -283,19 +298,28 @@ def rec_multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
     :param args: all :class:`ArrayContainer` arguments must be of the same
         type and with the same structure (same number of components, etc.).
     """
-    return _multimap_array_container_impl(f, *args, recursive=True)
+    return _multimap_array_container_impl(
+        f, *args, leaf_cls=leaf_class, recursive=True)
 
 
 def multimapped_over_array_containers(
-        f: Callable[..., Any]) -> Callable[..., Any]:
+        f: Optional[Callable[..., Any]] = None,
+        leaf_class: Optional[type] = None) -> Union[
+            Callable[..., Any],
+            Callable[[Callable[..., Any]], Callable[..., Any]]]:
     """Decorator around :func:`rec_multimap_array_container`."""
-    # can't use functools.partial, because its result is insufficiently
-    # function-y to be used as a method definition.
-    def wrapper(*args: Any) -> Any:
-        return rec_multimap_array_container(f, *args)
+    def decorator(g: Callable[..., Any]) -> Callable[..., Any]:
+        # can't use functools.partial, because its result is insufficiently
+        # function-y to be used as a method definition.
+        def wrapper(*args: Any) -> Any:
+            return rec_multimap_array_container(g, *args, leaf_class=leaf_class)
+        update_wrapper(wrapper, g)
+        return wrapper
+    if f is not None:
+        return decorator(f)
+    else:
+        return decorator
 
-    update_wrapper(wrapper, f)
-    return wrapper
 
 # }}}
 
@@ -401,7 +425,8 @@ def multimap_reduce_array_container(
 def rec_map_reduce_array_container(
         reduce_func: Callable[[Iterable[Any]], Any],
         map_func: Callable[[Any], Any],
-        ary: ArrayOrContainerT) -> "DeviceArray":
+        ary: ArrayOrContainerT,
+        leaf_class: Optional[type] = None) -> "DeviceArray":
     """Perform a map-reduce over array containers recursively.
 
     :param reduce_func: callable used to reduce over the components of *ary*
@@ -440,14 +465,17 @@ def rec_map_reduce_array_container(
         or any other such traversal.
     """
     def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
-        try:
-            iterable = serialize_container(_ary)
-        except NotAnArrayContainerError:
+        if type(_ary) is leaf_class:
             return map_func(_ary)
         else:
-            return reduce_func([
-                rec(subary) for _, subary in iterable
-                ])
+            try:
+                iterable = serialize_container(_ary)
+            except NotAnArrayContainerError:
+                return map_func(_ary)
+            else:
+                return reduce_func([
+                    rec(subary) for _, subary in iterable
+                    ])
 
     return rec(ary)
 
@@ -455,7 +483,8 @@ def rec_map_reduce_array_container(
 def rec_multimap_reduce_array_container(
         reduce_func: Callable[[Iterable[Any]], Any],
         map_func: Callable[..., Any],
-        *args: Any) -> "DeviceArray":
+        *args: Any,
+        leaf_class: Optional[type] = None) -> "DeviceArray":
     r"""Perform a map-reduce over multiple array containers recursively.
 
     :param reduce_func: callable used to reduce over the components of any
@@ -478,7 +507,7 @@ def rec_multimap_reduce_array_container(
 
     return _multimap_array_container_impl(
         map_func, *args,
-        reduce_func=_reduce_wrapper, leaf_cls=None, recursive=True)
+        reduce_func=_reduce_wrapper, leaf_cls=leaf_class, recursive=True)
 
 # }}}
 
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 0fe5480db54dadccab767306a2562ea5720d8cb1..ac7ebbf5798c232cc5ce8aa8bb4b382b98112d46 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -756,6 +756,59 @@ def test_container_scalar_map(actx_factory):
         assert result is not None
 
 
+def test_container_map(actx_factory):
+    actx = actx_factory()
+    ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
+            _get_test_containers(actx)
+
+    # {{{ check
+
+    def _check_allclose(f, arg1, arg2, atol=2.0e-14):
+        from arraycontext import NotAnArrayContainerError
+        try:
+            arg1_iterable = serialize_container(arg1)
+            arg2_iterable = serialize_container(arg2)
+        except NotAnArrayContainerError:
+            assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
+        else:
+            arg1_subarrays = [
+                subarray for _, subarray in arg1_iterable]
+            arg2_subarrays = [
+                subarray for _, subarray in arg2_iterable]
+            for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
+                _check_allclose(f, subarray1, subarray2)
+
+    def func(x):
+        return x + 1
+
+    from arraycontext import rec_map_array_container
+    result = rec_map_array_container(func, 1)
+    assert result == 2
+
+    for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
+        result = rec_map_array_container(func, ary)
+        _check_allclose(func, ary, result)
+
+    from arraycontext import mapped_over_array_containers
+
+    @mapped_over_array_containers
+    def mapped_func(x):
+        return func(x)
+
+    for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
+        result = mapped_func(ary)
+        _check_allclose(func, ary, result)
+
+    @mapped_over_array_containers(leaf_class=DOFArray)
+    def check_leaf(x):
+        assert isinstance(x, DOFArray)
+
+    for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
+        check_leaf(ary)
+
+    # }}}
+
+
 def test_container_multimap(actx_factory):
     actx = actx_factory()
     ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
@@ -764,7 +817,19 @@ def test_container_multimap(actx_factory):
     # {{{ check
 
     def _check_allclose(f, arg1, arg2, atol=2.0e-14):
-        assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
+        from arraycontext import NotAnArrayContainerError
+        try:
+            arg1_iterable = serialize_container(arg1)
+            arg2_iterable = serialize_container(arg2)
+        except NotAnArrayContainerError:
+            assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
+        else:
+            arg1_subarrays = [
+                subarray for _, subarray in arg1_iterable]
+            arg2_subarrays = [
+                subarray for _, subarray in arg2_iterable]
+            for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
+                _check_allclose(f, subarray1, subarray2)
 
     def func_all_scalar(x, y):
         return x + y
@@ -779,17 +844,30 @@ def test_container_multimap(actx_factory):
     result = rec_multimap_array_container(func_all_scalar, 1, 2)
     assert result == 3
 
-    from functools import partial
     for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
         result = rec_multimap_array_container(func_first_scalar, 1, ary)
-        rec_multimap_array_container(
-                partial(_check_allclose, lambda x: 1 + x),
-                ary, result)
+        _check_allclose(lambda x: 1 + x, ary, result)
 
         result = rec_multimap_array_container(func_multiple_scalar, 2, ary, 2, ary)
-        rec_multimap_array_container(
-                partial(_check_allclose, lambda x: 4 * x),
-                ary, result)
+        _check_allclose(lambda x: 4 * x, ary, result)
+
+    from arraycontext import multimapped_over_array_containers
+
+    @multimapped_over_array_containers
+    def mapped_func(a, subary1, b, subary2):
+        return func_multiple_scalar(a, subary1, b, subary2)
+
+    for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
+        result = mapped_func(2, ary, 2, ary)
+        _check_allclose(lambda x: 4 * x, ary, result)
+
+    @multimapped_over_array_containers(leaf_class=DOFArray)
+    def check_leaf(a, subary1, b, subary2):
+        assert isinstance(subary1, DOFArray)
+        assert isinstance(subary2, DOFArray)
+
+    for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
+        check_leaf(2, ary, 2, ary)
 
     with pytest.raises(AssertionError):
         rec_multimap_array_container(func_multiple_scalar, 2, ary_dof, 2, dc_of_dofs)