diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index e8f34d45be656c8dfa96b1cc83ab52655597556c..d9861d61591c155ae33bb0d6d4b25b8eb5da4692 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -64,9 +64,13 @@ from .container.traversal import (
         outer)
 
 from .impl.pyopencl import PyOpenCLArrayContext
-from .impl.pytato import PytatoPyOpenCLArrayContext
+from .impl.pytato import (PytatoPyOpenCLArrayContext,
+                          PytatoJAXArrayContext,
+                          _BasePytatoArrayContext)
+from .impl.jax import EagerJAXArrayContext
 
 from .pytest import (
+        PytestArrayContextFactory,
         PytestPyOpenCLArrayContextFactory,
         pytest_generate_tests_for_array_contexts,
         pytest_generate_tests_for_pyopencl_array_context)
@@ -102,9 +106,12 @@ __all__ = (
         "outer",
 
         "PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext",
+        "PytatoJAXArrayContext", "_BasePytatoArrayContext",
+        "EagerJAXArrayContext",
 
         "make_loopy_program",
 
+        "PytestArrayContextFactory",
         "PytestPyOpenCLArrayContextFactory",
         "pytest_generate_tests_for_array_contexts",
         "pytest_generate_tests_for_pyopencl_array_context"
diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72c9318c9be7f5b7d298a1fb4e47015fc87791a5
--- /dev/null
+++ b/arraycontext/impl/jax/__init__.py
@@ -0,0 +1,129 @@
+"""
+.. currentmodule:: arraycontext
+.. autoclass:: EagerJAXArrayContext
+"""
+
+__copyright__ = """
+Copyright (C) 2021 University of Illinois Board of Trustees
+"""
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import numpy as np
+
+from typing import Union, Callable, Any
+from pytools.tag import ToTagSetConvertible
+from arraycontext.context import ArrayContext, _ScalarLike
+
+
+class EagerJAXArrayContext(ArrayContext):
+    """
+    A :class:`ArrayContext` that uses
+    :class:`jaxlib.xla_extension.DeviceArrayBase` instances for its base array
+    class and performs all array operations eagerly. See
+    :class:`~arraycontext.PytatoJAXArrayContext` for a lazier version.
+
+    .. note::
+
+        JAX stores a global configuration state in :data:`jax.config`. Callers
+        are expected to maintain those. Most important for scientific computing
+        workloads being ``jax_enable_x64``.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+        from jax.numpy import DeviceArray
+        self.array_types = (DeviceArray, )
+
+    def _get_fake_numpy_namespace(self):
+        from .fake_numpy import EagerJAXFakeNumpyNamespace
+        return EagerJAXFakeNumpyNamespace(self)
+
+    # {{{ ArrayContext interface
+
+    def empty(self, shape, dtype):
+        import jax.numpy as jnp
+        return jnp.empty(shape=shape, dtype=dtype)
+
+    def zeros(self, shape, dtype):
+        import jax.numpy as jnp
+        return jnp.zeros(shape=shape, dtype=dtype)
+
+    def from_numpy(self, array: Union[np.ndarray, _ScalarLike]):
+        import jax
+        return jax.device_put(array)
+
+    def to_numpy(self, array):
+        import jax
+        # jax.device_get can take scalars as well.
+        return jax.device_get(array)
+
+    def call_loopy(self, t_unit, **kwargs):
+        raise NotImplementedError("calling loopy on JAX arrays"
+                                  " not supported. Maybe rewrite"
+                                  " the loopy kernel as numpy-flavored array"
+                                  " operations using ArrayContext.np.")
+
+    def freeze(self, array):
+        return array.block_until_ready()
+
+    def thaw(self, array):
+        return array
+
+    # }}}
+
+    def tag(self, tags: ToTagSetConvertible, array):
+        # Sorry, not capable.
+        return array
+
+    def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
+        # TODO: See `jax.experiemental.maps.xmap`, proabably that should be useful?
+        return array
+
+    def clone(self):
+        return type(self)()
+
+    def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
+        return f
+
+    def einsum(self, spec, *args, arg_names=None, tagged=()):
+        import jax.numpy as jnp
+        if arg_names is not None:
+            from warnings import warn
+            warn("'arg_names' don't bear any significance in "
+                 "EagerJAXArrayContext.", stacklevel=2)
+
+        return jnp.einsum(spec, *args)
+
+    @property
+    def permits_inplace_modification(self):
+        return False
+
+    @property
+    def supports_nonscalar_broadcasting(self):
+        return True
+
+    @property
+    def permits_advanced_indexing(self):
+        return True
+
+# vim: foldmethod=marker
diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0466eeeaf5de6b215708d6203b45ba60c60f3e3
--- /dev/null
+++ b/arraycontext/impl/jax/fake_numpy.py
@@ -0,0 +1,143 @@
+__copyright__ = """
+Copyright (C) 2021 University of Illinois Board of Trustees
+"""
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+from functools import partial, reduce
+
+from arraycontext.fake_numpy import (
+        BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace,
+        )
+from arraycontext.container.traversal import (
+        rec_multimap_array_container, rec_map_array_container,
+        rec_map_reduce_array_container,
+        )
+from arraycontext.container import NotAnArrayContainerError, serialize_container
+import numpy
+import jax.numpy as jnp
+
+
+class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
+    # Everything is implemented in the base class for now.
+    pass
+
+
+class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
+    """
+    A :mod:`numpy` mimic for :class:`~arraycontext.EagerJAXArrayContext`.
+    """
+    def _get_fake_numpy_linalg_namespace(self):
+        return EagerJAXFakeNumpyLinalgNamespace(self._array_context)
+
+    def __getattr__(self, name):
+        return partial(rec_multimap_array_container, getattr(jnp, name))
+
+    def reshape(self, a, newshape, order="C"):
+        return rec_map_array_container(
+            lambda ary: jnp.reshape(ary, newshape, order=order),
+            a)
+
+    def transpose(self, a, axes=None):
+        return rec_multimap_array_container(jnp.transpose, a, axes)
+
+    def concatenate(self, arrays, axis=0):
+        return rec_multimap_array_container(jnp.concatenate, arrays, axis)
+
+    def where(self, criterion, then, else_):
+        return rec_multimap_array_container(jnp.where, criterion, then, else_)
+
+    def sum(self, a, axis=None, dtype=None):
+        return rec_map_reduce_array_container(sum,
+                                              partial(jnp.sum,
+                                                      axis=axis,
+                                                      dtype=dtype),
+                                              a)
+
+    def min(self, a, axis=None):
+        return rec_map_reduce_array_container(
+                partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a)
+
+    def max(self, a, axis=None):
+        return rec_map_reduce_array_container(
+                partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a)
+
+    def stack(self, arrays, axis=0):
+        return rec_multimap_array_container(
+            lambda *args: jnp.stack(arrays=args, axis=axis),
+            *arrays)
+
+    def array_equal(self, a, b):
+        actx = self._array_context
+
+        # NOTE: not all backends support `bool` properly, so use `int8` instead
+        false = actx.from_numpy(numpy.int8(False))
+
+        def rec_equal(x, y):
+            if type(x) != type(y):
+                return false
+
+            try:
+                iterable = zip(serialize_container(x), serialize_container(y))
+            except NotAnArrayContainerError:
+                if x.shape != y.shape:
+                    return false
+                else:
+                    return jnp.all(jnp.equal(x, y))
+            else:
+                return reduce(
+                        jnp.logical_and,
+                        [rec_equal(ix, iy) for (_, ix), (_, iy) in iterable]
+                        )
+
+        return rec_equal(a, b)
+
+    def ravel(self, a, order="C"):
+        """
+        .. warning::
+
+            Since :func:`jax.numpy.reshape` does not support orders `A`` and
+            ``K``, in such cases we fallback to using ``order = C``.
+        """
+        if order in "AK":
+            from warnings import warn
+            warn(f"ravel with order='{order}' not supported by JAX,"
+                 " using order=C.")
+            order = "C"
+
+        return rec_map_array_container(lambda subary: jnp.ravel(subary, order=order),
+                                       a)
+
+    def vdot(self, x, y, dtype=None):
+        from arraycontext import rec_multimap_reduce_array_container
+
+        def _rec_vdot(ary1, ary2):
+            if dtype not in [None, numpy.find_common_type((ary1.dtype,
+                                                           ary2.dtype),
+                                                          ())]:
+                raise NotImplementedError(f"{type(self)} cannot take dtype in"
+                                          " vdot.")
+
+            return jnp.vdot(ary1, ary2)
+
+        return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y)
+
+    def broadcast_to(self, array, shape):
+        return rec_map_array_container(partial(jnp.broadcast_to, shape=shape), array)
diff --git a/doc/array_context.rst b/doc/array_context.rst
index 680d49bf793ade7aba8622e50f7f4f2734e3d289..c137c0899afa218e195a21369d3856f47b461f7e 100644
--- a/doc/array_context.rst
+++ b/doc/array_context.rst
@@ -19,6 +19,12 @@ Lazy/Deferred evaluation array context based on :mod:`pytato`
 
 .. automodule:: arraycontext.impl.pytato
 
+
+Array context :mod:`jax.numpy`
+-------------------------------------------------------------
+
+.. automodule:: arraycontext.impl.jax
+
 .. _numpy-coverage:
 
 :mod:`numpy` coverage
diff --git a/doc/conf.py b/doc/conf.py
index bee0e10b98630bf89addaabe7931fd5eeaf86478..c91572488c400ff40959254bf745ea1e7052639d 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -24,6 +24,7 @@ intersphinx_mapping = {
     "https://documen.tician.de/loopy": None,
     "https://documen.tician.de/meshmode": None,
     "https://docs.pytest.org/en/latest/": None,
+    "https://jax.readthedocs.io/en/latest/": None,
 }
 
 import sys