diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py
index 785feafba74cc88b4b98bf13daf8487e57e03a62..18e6cefc2c07ef8bfe614076be5f5728e101f98b 100644
--- a/test/test_pymbolic.py
+++ b/test/test_pymbolic.py
@@ -26,7 +26,11 @@ from functools import reduce
 import pymbolic.primitives as prim
 from pymbolic import parse
 from pytools.lex import ParseError
-from pymbolic.mapper import IdentityMapper
+from testlib import generate_random_expression
+
+
+from pymbolic.mapper import IdentityMapper, WalkMapper
+from pymbolic.mapper.dependency import DependencyMapper, CachedDependencyMapper
 
 import logging
 logger = logging.getLogger(__name__)
@@ -332,8 +336,6 @@ def test_mappers():
     for expr in [
             f(x, (y, z), name=z**2)
             ]:
-        from pymbolic.mapper import WalkMapper
-        from pymbolic.mapper.dependency import DependencyMapper
         str(expr)
         IdentityMapper()(expr)
         WalkMapper()(expr)
@@ -347,7 +349,6 @@ def test_mappers():
 
 def test_func_dep_consistency():
     from pymbolic import var
-    from pymbolic.mapper.dependency import DependencyMapper
     f = var("f")
     x = var("x")
     dep_map = DependencyMapper(include_calls="descend_args")
@@ -863,6 +864,70 @@ def test_equality_complexity():
 # }}}
 
 
+# {{{ test_cached_mapper_memoizes
+
+class InCacheVerifier(WalkMapper):
+    def __init__(self, cached_mapper, walk_call_functions=True):
+        super().__init__()
+        self.cached_mapper = cached_mapper
+        self.walk_call_functions = walk_call_functions
+
+    def post_visit(self, expr):
+        if isinstance(expr, prim.Expression):
+            assert (self.cached_mapper.get_cache_key(expr)
+                    in self.cached_mapper._cache)
+
+    def map_call(self, expr):
+        if not self.visit(expr):
+            return
+
+        if self.walk_call_functions:
+            self.rec(expr.function)
+
+        for child in expr.parameters:
+            self.rec(child)
+
+        self.post_visit(expr)
+
+
+def test_cached_mapper_memoizes():
+    from testlib import (AlwaysFlatteningIdentityMapper,
+                         AlwaysFlatteningCachedIdentityMapper)
+    ntests = 40
+    for i in range(ntests):
+        expr = generate_random_expression(seed=(5+i))
+
+        # {{{ always flattening identity mapper
+
+        # Note: Prefer AlwaysFlatteningIdentityMapper over IdentityMapper as
+        # the flattening logic in IdentityMapper checks for identity across
+        # traversal results => leading to discrepancy b/w
+        # 'CachedIdentityMapper' and 'IdentityMapper'
+
+        cached_mapper = AlwaysFlatteningCachedIdentityMapper()
+        uncached_mapper = AlwaysFlatteningIdentityMapper()
+        assert uncached_mapper(expr) == cached_mapper(expr)
+        verifier = InCacheVerifier(cached_mapper)
+        verifier(expr)
+
+        # }}}
+
+        # {{{ dependency mapper
+
+        mapper = DependencyMapper(include_calls="descend_args")
+        cached_mapper = CachedDependencyMapper(include_calls="descend_args")
+        assert cached_mapper(expr) == mapper(expr)
+        verifier = InCacheVerifier(cached_mapper,
+                                   # dep. mapper does not go over functions
+                                   walk_call_functions=False
+                                   )
+        verifier(expr)
+
+        # }}}
+
+# }}}
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
diff --git a/test/testlib.py b/test/testlib.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd75c444e56dcf1670ea6ba7376cbb354c07ce0d
--- /dev/null
+++ b/test/testlib.py
@@ -0,0 +1,125 @@
+__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+
+import numpy as np
+
+import pymbolic.primitives as prim
+from dataclasses import dataclass, replace
+from pytools import UniqueNameGenerator
+from pymbolic.mapper import IdentityMapper, CachedIdentityMapper
+
+
+@dataclass(frozen=True, eq=True)
+class RandomExpressionGeneratorContext:
+    rng: np.random.Generator
+    vng: UniqueNameGenerator
+    current_depth: int
+    max_depth: int
+
+    def with_increased_depth(self):
+        return replace(self, current_depth=self.current_depth+1)
+
+
+def _generate_random_expr_inner(
+        context: RandomExpressionGeneratorContext) -> prim.Expression:
+
+    if context.current_depth >= context.max_depth:
+        # force expression to be a leaf type
+        return context.rng.integers(0, 42)
+
+    bucket = context.rng.integers(0, 100) / 100.0
+
+    # {{{ set some distribution of expression types
+
+    # 'weight' is proportional to the probability of seeing an expression type
+    weights = [1, 1, 1, 1, 1]
+    expr_types = [prim.Variable, prim.Sum, prim.Product, prim.Quotient,
+                  prim.Call]
+    assert len(weights) == len(expr_types)
+
+    # }}}
+
+    buckets = np.cumsum(weights, dtype="float64")/np.sum(weights)
+
+    expr_type = expr_types[np.searchsorted(buckets, bucket)]
+
+    if expr_type == prim.Variable:
+        return prim.Variable(context.vng("x"))
+    elif expr_type in [prim.Sum, prim.Product]:
+        left = _generate_random_expr_inner(context.with_increased_depth())
+        right = _generate_random_expr_inner(context.with_increased_depth())
+        return expr_type((left, right))
+    elif expr_type == prim.Quotient:
+        num = _generate_random_expr_inner(context.with_increased_depth())
+        den = _generate_random_expr_inner(context.with_increased_depth())
+        return prim.Quotient(num, den)
+    elif expr_type == prim.Quotient:
+        num = _generate_random_expr_inner(context.with_increased_depth())
+        den = _generate_random_expr_inner(context.with_increased_depth())
+        return prim.Quotient(num, den)
+    elif expr_type == prim.Call:
+        nargs = 3
+        return prim.Variable(context.vng("f"))(
+            *[_generate_random_expr_inner(context.with_increased_depth())
+              for _ in range(nargs)])
+    else:
+        raise NotImplementedError(expr_type)
+
+
+def generate_random_expression(seed: int, max_depth: int = 8) -> prim.Expression:
+    from numpy.random import default_rng
+    rng = default_rng(seed)
+    vng = UniqueNameGenerator()
+
+    context = RandomExpressionGeneratorContext(rng,
+                                               vng=vng,
+                                               max_depth=max_depth,
+                                               current_depth=0)
+
+    return _generate_random_expr_inner(context)
+
+
+# {{{ custom mappers for tests
+
+class _AlwaysFlatteningMixin:
+    def map_sum(self, expr, *args, **kwargs):
+        children = [self.rec(child, *args, **kwargs) for child in expr.children]
+        from pymbolic.primitives import flattened_sum
+        return flattened_sum(children)
+
+    def map_product(self, expr, *args, **kwargs):
+        children = [self.rec(child, *args, **kwargs) for child in expr.children]
+        from pymbolic.primitives import flattened_product
+        return flattened_product(children)
+
+
+class AlwaysFlatteningIdentityMapper(_AlwaysFlatteningMixin,
+                                     IdentityMapper):
+    pass
+
+
+class AlwaysFlatteningCachedIdentityMapper(_AlwaysFlatteningMixin,
+                                           CachedIdentityMapper):
+    pass
+
+# }}}