diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py
index 3fc5f2e6eaee36eef7c4a3802df81fd1f4068fa5..bc9481e36067ab07300bf0c10431eed2fac889d0 100644
--- a/arraycontext/impl/jax/fake_numpy.py
+++ b/arraycontext/impl/jax/fake_numpy.py
@@ -27,12 +27,16 @@ import numpy as np
 
 import jax.numpy as jnp
 
-from arraycontext.container import NotAnArrayContainerError, serialize_container
+from arraycontext.container import (
+    NotAnArrayContainerError,
+    serialize_container,
+)
 from arraycontext.container.traversal import (
     rec_map_array_container,
     rec_map_reduce_array_container,
     rec_multimap_array_container,
 )
+from arraycontext.context import Array, ArrayOrContainer
 from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace
 
 
@@ -156,29 +160,35 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
         return rec_map_reduce_array_container(
             partial(reduce, jnp.logical_or), jnp.any, a)
 
-    def array_equal(self, a, b):
+    def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
         actx = self._array_context
 
         # NOTE: not all backends support `bool` properly, so use `int8` instead
-        true = actx.from_numpy(np.int8(True))
-        false = actx.from_numpy(np.int8(False))
+        true_ary = actx.from_numpy(np.int8(True))
+        false_ary = actx.from_numpy(np.int8(False))
 
         def rec_equal(x, y):
             if type(x) is not type(y):
-                return false
+                return false_ary
 
             try:
-                iterable = zip(serialize_container(x), serialize_container(y))
+                serialized_x = serialize_container(x)
+                serialized_y = serialize_container(y)
             except NotAnArrayContainerError:
                 if x.shape != y.shape:
-                    return false
+                    return false_ary
                 else:
                     return jnp.all(jnp.equal(x, y))
             else:
+                if len(serialized_x) != len(serialized_y):
+                    return false_ary
                 return reduce(
                         jnp.logical_and,
-                        [rec_equal(x_i, y_i) for (_, x_i), (_, y_i) in iterable],
-                        true)
+                        [(true_ary if kx_i == ky_i else false_ary)
+                            and rec_equal(x_i, y_i)
+                            for (kx_i, x_i), (ky_i, y_i)
+                            in zip(serialized_x, serialized_y)],
+                        true_ary)
 
         return rec_equal(a, b)
 
diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py
index b7a2335a1a50ca58a235e180b66c264e619ebabd..b305717e15af76e05e4ea6cec909875d1061e6b4 100644
--- a/arraycontext/impl/numpy/fake_numpy.py
+++ b/arraycontext/impl/numpy/fake_numpy.py
@@ -25,14 +25,14 @@ from functools import partial, reduce
 
 import numpy as np
 
-from arraycontext.container import is_array_container
+from arraycontext.container import NotAnArrayContainerError, serialize_container
 from arraycontext.container.traversal import (
-    multimap_reduce_array_container,
     rec_map_array_container,
     rec_map_reduce_array_container,
     rec_multimap_array_container,
     rec_multimap_reduce_array_container,
 )
+from arraycontext.context import Array, ArrayOrContainer
 from arraycontext.fake_numpy import (
     BaseFakeNumpyLinalgNamespace,
     BaseFakeNumpyNamespace,
@@ -127,18 +127,29 @@ class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace):
         return rec_map_reduce_array_container(partial(reduce, np.logical_and),
                                               lambda subary: np.all(subary), a)
 
-    def array_equal(self, a, b):
-        if type(a) != type(b):
-            return False
-        elif not is_array_container(a):
-            if a.shape != b.shape:
-                return False
-            else:
-                return np.all(np.equal(a, b))
+    def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
+        false_ary = np.array(False)
+        true_ary = np.array(True)
+        if type(a) is not type(b):
+            return false_ary
+
+        try:
+            serialized_x = serialize_container(a)
+            serialized_y = serialize_container(b)
+        except NotAnArrayContainerError:
+            assert isinstance(a, np.ndarray)
+            assert isinstance(b, np.ndarray)
+            return np.array(np.array_equal(a, b))
         else:
-            return multimap_reduce_array_container(partial(reduce,
-                                                           np.logical_and),
-                                                   self.array_equal, a, b)
+            if len(serialized_x) != len(serialized_y):
+                return false_ary
+            return reduce(
+                    np.logical_and,
+                    [(true_ary if kx_i == ky_i else false_ary)
+                        and self.array_equal(x_i, y_i)
+                        for (kx_i, x_i), (ky_i, y_i)
+                        in zip(serialized_x, serialized_y)],
+                    true_ary)
 
     def arange(self, *args, **kwargs):
         return np.arange(*args, **kwargs)
diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py
index 59be99e8a2c4099882d7a8816e2872f7d1d0b7fc..848870a9c104b719346a0bb318ec200ba0c93300 100644
--- a/arraycontext/impl/pyopencl/fake_numpy.py
+++ b/arraycontext/impl/pyopencl/fake_numpy.py
@@ -38,6 +38,7 @@ from arraycontext.container.traversal import (
     rec_multimap_array_container,
     rec_multimap_reduce_array_container,
 )
+from arraycontext.context import Array, ArrayOrContainer
 from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
 from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
 from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
@@ -215,30 +216,40 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
             result = result.get()[()]
         return result
 
-    def array_equal(self, a, b):
+    def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
         actx = self._array_context
         queue = actx.queue
 
         # NOTE: pyopencl doesn't like `bool` much, so use `int8` instead
-        true = actx.from_numpy(np.int8(True))
-        false = actx.from_numpy(np.int8(False))
+        true_ary = actx.from_numpy(np.int8(True))
+        false_ary = actx.from_numpy(np.int8(False))
 
-        def rec_equal(x, y):
+        def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> cl_array.Array:
             if type(x) is not type(y):
-                return false
+                return false_ary
 
             try:
-                iterable = zip(serialize_container(x), serialize_container(y))
+                serialized_x = serialize_container(x)
+                serialized_y = serialize_container(y)
             except NotAnArrayContainerError:
+                assert isinstance(x, cl_array.Array)
+                assert isinstance(y, cl_array.Array)
+
                 if x.shape != y.shape:
-                    return false
+                    return false_ary
                 else:
                     return (x == y).all()
             else:
+                if len(serialized_x) != len(serialized_y):
+                    return false_ary
+
                 return reduce(
                         partial(cl_array.minimum, queue=queue),
-                        [rec_equal(x_i, y_i)for (_, x_i), (_, y_i) in iterable],
-                        true)
+                        [(true_ary if kx_i == ky_i else false_ary)
+                            and rec_equal(x_i, y_i)
+                            for (kx_i, x_i), (ky_i, y_i)
+                            in zip(serialized_x, serialized_y)],
+                        true_ary)
 
         result = rec_equal(a, b)
         if not self._array_context._force_device_scalars:
diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py
index d3d018d6434cefe933c7774225605c2954128e71..c6508e3aba7ac8b9dcf9373c5ca69fb5f6d416d5 100644
--- a/arraycontext/impl/pytato/fake_numpy.py
+++ b/arraycontext/impl/pytato/fake_numpy.py
@@ -22,7 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 from functools import partial, reduce
-from typing import Any
+from typing import Any, cast
 
 import numpy as np
 
@@ -34,6 +34,7 @@ from arraycontext.container.traversal import (
     rec_map_reduce_array_container,
     rec_multimap_array_container,
 )
+from arraycontext.context import Array, ArrayOrContainer
 from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
 from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
 
@@ -171,31 +172,41 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
                 partial(reduce, pt.logical_or),
                 lambda subary: pt.any(subary), a)
 
-    def array_equal(self, a, b):
+    def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
         actx = self._array_context
 
         # NOTE: not all backends support `bool` properly, so use `int8` instead
-        true = actx.from_numpy(np.int8(True))
-        false = actx.from_numpy(np.int8(False))
+        true_ary = actx.from_numpy(np.int8(True))
+        false_ary = actx.from_numpy(np.int8(False))
 
-        def rec_equal(x, y):
+        def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> pt.Array:
             if type(x) is not type(y):
-                return false
+                return false_ary
 
             try:
-                iterable = zip(serialize_container(x), serialize_container(y))
+                serialized_x = serialize_container(x)
+                serialized_y = serialize_container(y)
             except NotAnArrayContainerError:
+                assert isinstance(x, pt.Array)
+                assert isinstance(y, pt.Array)
+
                 if x.shape != y.shape:
-                    return false
+                    return false_ary
                 else:
-                    return pt.all(pt.equal(x, y))
+                    return pt.all(cast(pt.Array, pt.equal(x, y)))
             else:
+                if len(serialized_x) != len(serialized_y):
+                    return false_ary
+
                 return reduce(
                         pt.logical_and,
-                        [rec_equal(x_i, y_i) for (_, x_i), (_, y_i) in iterable],
-                        true)
+                        [(true_ary if kx_i == ky_i else false_ary)
+                            and rec_equal(x_i, y_i)
+                            for (kx_i, x_i), (ky_i, y_i)
+                            in zip(serialized_x, serialized_y)],
+                        true_ary)
 
-        return rec_equal(a, b)
+        return cast(Array, rec_equal(a, b))
 
     # }}}