diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 04f5f4176a1da87f5b6f970ac16d52071e28d9f0..eacdd2a661265d391cb098f1f1fbdec4461b1097 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -256,6 +256,17 @@ def get_axis_0_ranking(kernel, insn):
 
     # }}}
 
+    # {{{ figure out axis 0 candidates
+
+    from loopy.kernel import AutoLocalIndexTagBase
+    axis0_candidates = set(
+            iname
+            for iname in insn.all_inames()
+            if isinstance(kernel.iname_to_tag.get(iname),
+                AutoLocalIndexTagBase))
+
+    # }}}
+
     # {{{ figure out which iname should get mapped to local axis 0
 
     # maps inames to vote counts
@@ -265,6 +276,8 @@ def get_axis_0_ranking(kernel, insn):
 
     from pytools import argmin2
 
+    saw_relevant_access = False
+
     for aae in global_ary_acc_exprs:
         index_expr = aae.index
         if not isinstance(index_expr, tuple):
@@ -277,6 +290,8 @@ def get_axis_0_ranking(kernel, insn):
         if ary_strides is None and len(index_expr) == 1:
             ary_strides = (1,)
 
+        # {{{ construct iname_to_stride
+
         iname_to_stride = {}
         for iexpr_i, stride in zip(index_expr, ary_strides):
             coeffs = CoefficientCollector()(iexpr_i)
@@ -287,6 +302,11 @@ def get_axis_0_ranking(kernel, insn):
                     if old_stride is None or new_stride < old_stride:
                         iname_to_stride[var_name] = new_stride
 
+        # }}}
+
+        if set(iname_to_stride.keys()) & axis0_candidates:
+            saw_relevant_access = True
+
         if iname_to_stride:
             from pymbolic import evaluate
             least_stride_iname, least_stride = argmin2((
@@ -304,9 +324,12 @@ def get_axis_0_ranking(kernel, insn):
                     vote_count_for_l0.get(least_stride_iname, 0)
                     + vote_strength)
 
-    return sorted((iname for iname in insn.all_inames()),
-            key=lambda iname: vote_count_for_l0.get(iname, 0),
-            reverse=True)
+    if saw_relevant_access:
+        return sorted((iname for iname in insn.all_inames()),
+                key=lambda iname: vote_count_for_l0.get(iname, 0),
+                reverse=True)
+    else:
+        return None
 
     # }}}
 
@@ -314,20 +337,31 @@ def get_axis_0_ranking(kernel, insn):
 
 # {{{ assign automatic axes
 
-def assign_automatic_axes(kernel, only_axis_0=True):
+def assign_automatic_axes(kernel, phase="axis0", local_size=None):
     from loopy.kernel import (AutoLocalIndexTagBase, LocalIndexTag,
             UnrollTag)
 
-    global_size, local_size = kernel.get_grid_sizes_as_exprs(
-            ignore_auto=True)
+    # Realize that at this point in time, axis lengths are already
+    # fixed. So we compute them once and pass them to our recursive
+    # copies.
+
+    if local_size is None:
+        _, local_size = kernel.get_grid_sizes_as_exprs(
+                ignore_auto=True)
+
+    # {{{ axis assignment helper function
 
     def assign_axis(iname, axis=None):
+        """Assign iname to local axis *axis* and start over by calling
+        the surrounding function assign_automatic_axes.
+
+        If *axis* is None, find a suitable axis automatically.
+        """
         desired_length = kernel.get_constant_iname_length(iname)
 
         if axis is None:
             # {{{ find a suitable axis
 
-            # find already assigned local axes (to avoid them)
             shorter_possible_axes = []
             test_axis = 0
             while True:
@@ -345,10 +379,14 @@ def assign_automatic_axes(kernel, only_axis_0=True):
                     axis = test_axis
                     break
 
-            # longest first
-            shorter_possible_axes.sort(key=lambda ax: local_size[ax])
+            # The loop above will find an unassigned local axis
+            # that has enough 'room' for the iname. In the same traversal,
+            # it also finds theoretically assignable axes that are shorter,
+            # in the variable shorter_possible_axes.
 
             if axis is None and shorter_possible_axes:
+                # sort as longest first
+                shorter_possible_axes.sort(key=lambda ax: local_size[ax])
                 axis = shorter_possible_axes[0]
 
             # }}}
@@ -363,12 +401,23 @@ def assign_automatic_axes(kernel, only_axis_0=True):
                         split_dimension(kernel, iname, inner_length=local_size[axis],
                             outer_tag=UnrollTag(), inner_tag=new_tag,
                             do_tagged_check=False),
-                        only_axis_0=only_axis_0)
+                        phase=phase, local_size=local_size)
 
         new_iname_to_tag = kernel.iname_to_tag.copy()
         new_iname_to_tag[iname] = new_tag
         return assign_automatic_axes(kernel.copy(iname_to_tag=new_iname_to_tag),
-                only_axis_0=only_axis_0)
+                phase=phase, local_size=local_size)
+
+    # }}}
+
+    # {{{ main assignment loop
+
+    # assignment proceeds in two phases:
+
+    # - "axis0": Only axis 0 is assigned on instructions that carry out
+    #   global array access based on l.auto axes
+    #
+    # - "rest": All other l.auto axes are assigned haphazardly.
 
     for insn in kernel.instructions:
         auto_axis_inames = [
@@ -388,28 +437,32 @@ def assign_automatic_axes(kernel, only_axis_0=True):
                 assigned_local_axes.add(tag.axis)
 
         if 0 < len(local_size) and 0 not in assigned_local_axes:
-            for axis0_iname in get_axis_0_ranking(kernel, insn):
-                axis0_iname_tag = kernel.iname_to_tag.get(axis0_iname)
-                if isinstance(axis0_iname_tag, AutoLocalIndexTagBase):
-                    return assign_axis(axis0_iname, 0)
-
-        if only_axis_0:
+            axis0_ranking = get_axis_0_ranking(kernel, insn)
+            if axis0_ranking is not None:
+                for axis0_iname in axis0_ranking:
+                    axis0_iname_tag = kernel.iname_to_tag.get(axis0_iname)
+                    if isinstance(axis0_iname_tag, AutoLocalIndexTagBase):
+                        return assign_axis(axis0_iname, 0)
+
+        if phase == "axis0":
             continue
 
         # assign longest auto axis inames first
         auto_axis_inames.sort(key=kernel.get_constant_iname_length, reverse=True)
 
-        next_axis = 0
         if auto_axis_inames:
             return assign_axis(auto_axis_inames.pop())
 
+    # }}}
+
     # We've seen all instructions and not punted to recursion/restart because
     # of a new axis assignment.
 
-    if only_axis_0:
+    if phase == "axis0":
         # If we were only assigining axis 0, then assign all the remaining
         # axes next.
-        return assign_automatic_axes(kernel, only_axis_0=False)
+        return assign_automatic_axes(kernel, phase="rest",
+                local_size=local_size)
     else:
         # If were already assigning all axes and got here, we're now done.
         # All automatic axes are assigned.