diff --git a/pytools/__init__.py b/pytools/__init__.py
index c744b3274e6ab81c093c2a3610576c62b4f9dba0..ed02ed0af734093c34a2b700ca89bb64649e4e18 100644
--- a/pytools/__init__.py
+++ b/pytools/__init__.py
@@ -1072,6 +1072,9 @@ def wandering_element(length, wanderer=1, landscape=0):
 
 
 def indices_in_shape(shape):
+    if isinstance(shape, int):
+        shape = (shape,)
+
     if len(shape) == 0:
         yield ()
     elif len(shape) == 1: