diff --git a/grudge/geometry/metrics.py b/grudge/geometry/metrics.py
index 75e021f5cfac035f9dd96021548e255a3512b37f..2ac1f7012724aafd071aeb2c1dda66e89106e23b 100644
--- a/grudge/geometry/metrics.py
+++ b/grudge/geometry/metrics.py
@@ -70,7 +70,7 @@ from grudge.dof_desc import (
 from pymbolic.geometric_algebra import MultiVector
 
 from pytools.obj_array import make_obj_array
-from pytools import memoize_on_first_arg
+from pytools import memoize_in
 
 
 # {{{ Metric computations
@@ -373,7 +373,6 @@ def inverse_metric_derivative(
     return result
 
 
-@memoize_on_first_arg
 def inverse_surface_metric_derivative(
         actx: ArrayContext, dcoll: DiscretizationCollection,
         rst_axis, xyz_axis, dd=None):
@@ -390,18 +389,25 @@ def inverse_surface_metric_derivative(
         inverse metric derivative at each nodal coordinate.
     """
     dim = dcoll.dim
-    if dcoll.ambient_dim == dim:
-        imd = inverse_metric_derivative(
-            actx, dcoll, rst_axis, xyz_axis, dd=dd
-        )
-    else:
-        inv_form1 = inverse_first_fundamental_form(actx, dcoll, dim=dim, dd=dd)
-        imd = sum(
-            inv_form1[rst_axis, d]*forward_metric_nth_derivative(
-                actx, dcoll, xyz_axis, d, dd=dd
-            ) for d in range(dim)
-        )
-    return imd
+    ambient_dim = dcoll.ambient_dim
+
+    @memoize_in(dcoll, (inverse_surface_metric_derivative, dd,
+                        "inv_metric_deriv_rst%s_xyz%s_adim%s_gdim%s"
+                        % (rst_axis, xyz_axis, ambient_dim, dim)))
+    def _inv_surf_metric_deriv():
+        if ambient_dim == dim:
+            imd = inverse_metric_derivative(
+                actx, dcoll, rst_axis, xyz_axis, dd=dd
+            )
+        else:
+            inv_form1 = inverse_first_fundamental_form(actx, dcoll, dim=dim, dd=dd)
+            imd = sum(
+                inv_form1[rst_axis, d]*forward_metric_nth_derivative(
+                    actx, dcoll, xyz_axis, d, dd=dd
+                ) for d in range(dim)
+            )
+        return imd
+    return _inv_surf_metric_deriv()
 
 
 def _signed_face_ones(
@@ -480,7 +486,6 @@ def pseudoscalar(actx: ArrayContext, dcoll: DiscretizationCollection,
     ).project_max_grade()
 
 
-@memoize_on_first_arg
 def area_element(
         actx: ArrayContext, dcoll: DiscretizationCollection, dim=None, dd=None
         ) -> DOFArray:
@@ -494,10 +499,14 @@ def area_element(
     :returns: a :class:`~meshmode.dof_array.DOFArray` containing the transformed
         volumes for each element.
     """
-
-    return actx.np.sqrt(
-        pseudoscalar(actx, dcoll, dim=dim, dd=dd).norm_squared()
-    )
+    @memoize_in(dcoll, (area_element, dd,
+                        "area_elements_adim%s_gdim%s"
+                        % (dcoll.ambient_dim, dim)))
+    def _area_elements():
+        return actx.np.sqrt(
+            pseudoscalar(actx, dcoll, dim=dim, dd=dd).norm_squared()
+        )
+    return _area_elements()
 
 # }}}