diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py
index d0466eeeaf5de6b215708d6203b45ba60c60f3e3..8a72d9aa41b7cf29b8f7e4248e879818c6bcdbf8 100644
--- a/arraycontext/impl/jax/fake_numpy.py
+++ b/arraycontext/impl/jax/fake_numpy.py
@@ -50,40 +50,81 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
     def __getattr__(self, name):
         return partial(rec_multimap_array_container, getattr(jnp, name))
 
+    # NOTE: the order of these follows the order in numpy docs
+    # NOTE: when adding a function here, also add it to `array_context.rst` docs!
+
+    # {{{ array creation routines
+
+    def ones_like(self, ary):
+        return self.full_like(ary, 1)
+
+    def full_like(self, ary, fill_value):
+        def _full_like(subary):
+            return jnp.full_like(ary, fill_value)
+
+        return self._new_like(ary, _full_like)
+
+    # }}}
+
+    # {{{ array manipulation routies
+
     def reshape(self, a, newshape, order="C"):
         return rec_map_array_container(
             lambda ary: jnp.reshape(ary, newshape, order=order),
             a)
 
-    def transpose(self, a, axes=None):
-        return rec_multimap_array_container(jnp.transpose, a, axes)
+    def ravel(self, a, order="C"):
+        """
+        .. warning::
 
-    def concatenate(self, arrays, axis=0):
-        return rec_multimap_array_container(jnp.concatenate, arrays, axis)
+            Since :func:`jax.numpy.reshape` does not support orders `A`` and
+            ``K``, in such cases we fallback to using ``order = C``.
+        """
+        if order in "AK":
+            from warnings import warn
+            warn(f"ravel with order='{order}' not supported by JAX,"
+                 " using order=C.")
+            order = "C"
 
-    def where(self, criterion, then, else_):
-        return rec_multimap_array_container(jnp.where, criterion, then, else_)
+        return rec_map_array_container(
+            lambda subary: jnp.ravel(subary, order=order), a)
 
-    def sum(self, a, axis=None, dtype=None):
-        return rec_map_reduce_array_container(sum,
-                                              partial(jnp.sum,
-                                                      axis=axis,
-                                                      dtype=dtype),
-                                              a)
+    def transpose(self, a, axes=None):
+        return rec_multimap_array_container(jnp.transpose, a, axes)
 
-    def min(self, a, axis=None):
-        return rec_map_reduce_array_container(
-                partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a)
+    def broadcast_to(self, array, shape):
+        return rec_map_array_container(partial(jnp.broadcast_to, shape=shape), array)
 
-    def max(self, a, axis=None):
-        return rec_map_reduce_array_container(
-                partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a)
+    def concatenate(self, arrays, axis=0):
+        return rec_multimap_array_container(jnp.concatenate, arrays, axis)
 
     def stack(self, arrays, axis=0):
         return rec_multimap_array_container(
             lambda *args: jnp.stack(arrays=args, axis=axis),
             *arrays)
 
+    # }}}
+
+    # {{{ linear algebra
+
+    def vdot(self, x, y, dtype=None):
+        from arraycontext import rec_multimap_reduce_array_container
+
+        def _rec_vdot(ary1, ary2):
+            if dtype not in [None, numpy.find_common_type((ary1.dtype,
+                                                           ary2.dtype),
+                                                          ())]:
+                raise NotImplementedError(f"{type(self)} cannot take dtype in"
+                                          " vdot.")
+
+            return jnp.vdot(ary1, ary2)
+
+        return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y)
+
+    # }}}
+
+    # {{{ logic functions
+
     def array_equal(self, a, b):
         actx = self._array_context
 
@@ -109,35 +150,33 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
 
         return rec_equal(a, b)
 
-    def ravel(self, a, order="C"):
-        """
-        .. warning::
+    # }}}
 
-            Since :func:`jax.numpy.reshape` does not support orders `A`` and
-            ``K``, in such cases we fallback to using ``order = C``.
-        """
-        if order in "AK":
-            from warnings import warn
-            warn(f"ravel with order='{order}' not supported by JAX,"
-                 " using order=C.")
-            order = "C"
+    # {{{ mathematical functions
+
+    def sum(self, a, axis=None, dtype=None):
+        return rec_map_reduce_array_container(
+            sum,
+            partial(jnp.sum, axis=axis, dtype=dtype),
+            a)
 
-        return rec_map_array_container(lambda subary: jnp.ravel(subary, order=order),
-                                       a)
+    def amin(self, a, axis=None):
+        return rec_map_reduce_array_container(
+                partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a)
 
-    def vdot(self, x, y, dtype=None):
-        from arraycontext import rec_multimap_reduce_array_container
+    min = amin
 
-        def _rec_vdot(ary1, ary2):
-            if dtype not in [None, numpy.find_common_type((ary1.dtype,
-                                                           ary2.dtype),
-                                                          ())]:
-                raise NotImplementedError(f"{type(self)} cannot take dtype in"
-                                          " vdot.")
+    def amax(self, a, axis=None):
+        return rec_map_reduce_array_container(
+                partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a)
 
-            return jnp.vdot(ary1, ary2)
+    max = amax
 
-        return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y)
+    # }}}
 
-    def broadcast_to(self, array, shape):
-        return rec_map_array_container(partial(jnp.broadcast_to, shape=shape), array)
+    # {{{ sorting, searching and counting
+
+    def where(self, criterion, then, else_):
+        return rec_multimap_array_container(jnp.where, criterion, then, else_)
+
+    # }}}