diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index 019b899594ee1670d1ee59654f58b93ff8afacb9..d1efd4a44254a6d829ce63b66bb228cea1efab6a 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -1392,11 +1392,25 @@ def create_temporaries(knl, default_order):
 
 # {{{ determine shapes of temporaries
 
-def find_var_shape(knl, var_name, feed_expression):
-    from loopy.symbolic import AccessRangeMapper, SubstitutionRuleExpander
+def find_shapes_of_vars(knl, var_names, feed_expression):
+    """Determine base indices and shapes of several variables at once.
+
+    :arg var_names: names of the variables whose accesses are analyzed.
+    :arg feed_expression: a callable taking a single argument *receiver*
+        and calling ``receiver(expr, inames)`` for every expression that
+        should be taken into account.
+    :returns: a tuple ``(var_to_base_indices, var_to_shape, var_to_error)``
+        of dictionaries keyed by variable name. A variable for which
+        :class:`loopy.diagnostic.StaticValueFindingError` was raised appears
+        only in *var_to_error*, mapped to the error message. A variable for
+        which no subscripts were found gets ``()`` for both entries.
+    :raises RuntimeError: if a variable has no determinable access range
+        but was seen with subscripts whose indices could not be determined.
+    """
+    from loopy.symbolic import BatchedAccessRangeMapper, SubstitutionRuleExpander
     submap = SubstitutionRuleExpander(knl.substitutions)
 
-    armap = AccessRangeMapper(knl, var_name)
+    armap = BatchedAccessRangeMapper(knl, var_names)
 
     def run_through_armap(expr, inames):
         armap(submap(expr), inames)
@@ -1404,61 +1404,117 @@ def find_var_shape(knl, var_name, feed_expression):
 
     feed_expression(run_through_armap)
 
-    if armap.access_range is not None:
-        base_indices, shape = list(zip(*[
-                knl.cache_manager.base_index_and_length(
-                    armap.access_range, i)
-                for i in range(armap.access_range.dim(dim_type.set))]))
-    else:
-        if armap.bad_subscripts:
-            raise RuntimeError("cannot determine access range for '%s': "
-                    "undetermined index in subscript(s) '%s'"
-                    % (var_name, ", ".join(
-                            str(i) for i in armap.bad_subscripts)))
+    var_to_base_indices = {}
+    var_to_shape = {}
+    var_to_error = {}
+
+    from loopy.diagnostic import StaticValueFindingError
+
+    for var_name in var_names:
+        access_range = armap.access_ranges[var_name]
+        bad_subscripts = armap.bad_subscripts[var_name]
+
+        if access_range is not None:
+            try:
+                base_indices, shape = list(zip(*[
+                        knl.cache_manager.base_index_and_length(
+                            access_range, i)
+                        for i in range(access_range.dim(dim_type.set))]))
+            except StaticValueFindingError as e:
+                # record the failure and leave recovery to the caller
+                var_to_error[var_name] = str(e)
+                continue
+
+        else:
+            if bad_subscripts:
+                raise RuntimeError("cannot determine access range for '%s': "
+                        "undetermined index in subscript(s) '%s'"
+                        % (var_name, ", ".join(
+                                str(i) for i in bad_subscripts)))
+
+            # no subscripts found, let's call it a scalar
+            base_indices = ()
+            shape = ()
 
-        # no subscripts found, let's call it a scalar
-        base_indices = ()
-        shape = ()
+        var_to_base_indices[var_name] = base_indices
+        var_to_shape[var_name] = shape
 
-    return base_indices, shape
+    return var_to_base_indices, var_to_shape, var_to_error
 
 
 def determine_shapes_of_temporaries(knl):
-    new_temp_vars = knl.temporary_variables.copy()
+    """Infer shapes and base indices of temporaries that request ``lp.auto``.
+
+    The expressions of all instructions are analyzed in a single call to
+    :func:`find_shapes_of_vars` covering every temporary that needs
+    inference. Temporaries for which that call reports an error are
+    retried using instruction assignees only (the "legacy method"), after
+    a ``temp_shape_fallback`` warning has been issued for each of them.
+
+    :returns: a copy of *knl* whose temporaries have ``lp.auto`` shapes and
+        base indices replaced by the inferred values.
+    :raises LoopyError: if the retry reports errors as well.
+    """
 
     import loopy as lp
-    from loopy.diagnostic import StaticValueFindingError
 
-    new_temp_vars = {}
+    vars_needing_shape_inference = set()
+
     for tv in six.itervalues(knl.temporary_variables):
         if tv.shape is lp.auto or tv.base_indices is lp.auto:
-            def feed_all_expressions(receiver):
-                for insn in knl.instructions:
-                    insn.with_transformed_expressions(
-                            lambda expr: receiver(expr, knl.insn_inames(insn)))
+            vars_needing_shape_inference.add(tv.name)
 
-            def feed_assignee_of_instruction(receiver):
-                for insn in knl.instructions:
-                    for assignee in insn.assignees:
-                        receiver(assignee, knl.insn_inames(insn))
+    def feed_all_expressions(receiver):
+        for insn in knl.instructions:
+            insn.with_transformed_expressions(
+                lambda expr: receiver(expr, knl.insn_inames(insn)))
 
-            try:
-                base_indices, shape = find_var_shape(
-                        knl, tv.name, feed_all_expressions)
-            except StaticValueFindingError as e:
-                warn_with_kernel(knl, "temp_shape_fallback",
-                        "Had to fall back to legacy method of determining "
-                        "shape of temporary '%s' because: %s"
-                        % (tv.name, str(e)))
+    var_to_base_indices, var_to_shape, var_to_error = (
+        find_shapes_of_vars(
+                knl, vars_needing_shape_inference, feed_all_expressions))
+
+    # {{{ fall back to legacy method
+
+    if var_to_error:
+        # only the variables that failed above are retried
+        vars_needing_shape_inference = set(var_to_error)
+
+        for varname, err in six.iteritems(var_to_error):
+            warn_with_kernel(knl, "temp_shape_fallback",
+                             "Had to fall back to legacy method of determining "
+                             "shape of temporary '%s' because: %s"
+                             % (varname, err))
+
+        def feed_assignee_of_instruction(receiver):
+            for insn in knl.instructions:
+                for assignee in insn.assignees:
+                    receiver(assignee, knl.insn_inames(insn))
+
+        var_to_base_indices_fallback, var_to_shape_fallback, var_to_error = (
+            find_shapes_of_vars(
+                    knl, vars_needing_shape_inference, feed_assignee_of_instruction))
 
-                base_indices, shape = find_var_shape(
-                        knl, tv.name, feed_assignee_of_instruction)
+        if var_to_error:
+            # No way around errors: propagate an exception upward.
+            formatted_errors = (
+                "\n\n".join("'%s': %s" % (varname, var_to_error[varname])
+                for varname in sorted(var_to_error)))
 
-            if tv.base_indices is lp.auto:
-                tv = tv.copy(base_indices=base_indices)
-            if tv.shape is lp.auto:
-                tv = tv.copy(shape=shape)
+            raise LoopyError("got the following exception(s) trying to find the "
+                    "shape of temporary variables: %s" % formatted_errors)
 
+        var_to_base_indices.update(var_to_base_indices_fallback)
+        var_to_shape.update(var_to_shape_fallback)
+
+    # }}}
+
+    new_temp_vars = {}
+
+    for tv in six.itervalues(knl.temporary_variables):
+        if tv.base_indices is lp.auto:
+            tv = tv.copy(base_indices=var_to_base_indices[tv.name])
+        if tv.shape is lp.auto:
+            tv = tv.copy(shape=var_to_shape[tv.name])
         new_temp_vars[tv.name] = tv
 
     return knl.copy(temporary_variables=new_temp_vars)
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 52fd6e57f92e7f9599a3a0fb4256f97347708303..b14fba5706b83c94a86b66079925939567d60594 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -1471,12 +1471,22 @@ def get_access_range(domain, subscript, assumptions):
 
 # {{{ access range mapper
 
-class AccessRangeMapper(WalkMapper):
-    def __init__(self, kernel, arg_name):
+class BatchedAccessRangeMapper(WalkMapper):
+    """Walk expressions and record the access ranges of several variables.
+
+    :attr:`access_ranges` maps each name in *arg_names* to the access ranges
+    of its subscripts combined with ``|``, or to *None* if none was found.
+    :attr:`bad_subscripts` maps each name to a list of the subscripts of
+    that variable for which no access range could be determined.
+    """
+
+    def __init__(self, kernel, arg_names):
         self.kernel = kernel
-        self.arg_name = arg_name
-        self.access_range = None
-        self.bad_subscripts = []
+        # Build the dicts from the set so that *arg_names* may be any
+        # iterable, including one that can only be traversed once.
+        self.arg_names = set(arg_names)
+        self.access_ranges = dict((arg, None) for arg in self.arg_names)
+        self.bad_subscripts = dict((arg, []) for arg in self.arg_names)
 
     def map_subscript(self, expr, inames):
         domain = self.kernel.get_inames_domain(inames)
@@ -1484,38 +1485,70 @@ class AccessRangeMapper(WalkMapper):
 
         assert isinstance(expr.aggregate, p.Variable)
 
-        if expr.aggregate.name != self.arg_name:
+        if expr.aggregate.name not in self.arg_names:
             return
 
+        arg_name = expr.aggregate.name
         subscript = expr.index_tuple
 
+        # Indices depending on names that are not in the domain's var dict
+        # cannot be analyzed; record the subscript as bad for this variable.
         if not get_dependencies(subscript) <= set(domain.get_var_dict()):
-            self.bad_subscripts.append(expr)
+            self.bad_subscripts[arg_name].append(expr)
             return
 
         access_range = get_access_range(domain, subscript, self.kernel.assumptions)
 
-        if self.access_range is None:
-            self.access_range = access_range
+        if self.access_ranges[arg_name] is None:
+            self.access_ranges[arg_name] = access_range
         else:
-            if (self.access_range.dim(dim_type.set)
+            if (self.access_ranges[arg_name].dim(dim_type.set)
                     != access_range.dim(dim_type.set)):
                 raise RuntimeError(
                         "error while determining shape of argument '%s': "
                         "varying number of indices encountered"
-                        % self.arg_name)
+                        % arg_name)
 
-            self.access_range = self.access_range | access_range
+            # combine with the accesses seen so far for this variable
+            self.access_ranges[arg_name] = (
+                    self.access_ranges[arg_name] | access_range)
 
     def map_linear_subscript(self, expr, inames):
         self.rec(expr.index, inames)
 
-        if expr.aggregate.name == self.arg_name:
-            self.bad_subscripts.append(expr)
+        # linear subscripts of tracked variables are always recorded as bad
+        if expr.aggregate.name in self.arg_names:
+            self.bad_subscripts[expr.aggregate.name].append(expr)
 
     def map_reduction(self, expr, inames):
         return WalkMapper.map_reduction(self, expr, inames | set(expr.inames))
 
+
+class AccessRangeMapper(object):
+    """Single-variable interface to :class:`BatchedAccessRangeMapper`.
+
+    Calling an instance forwards to a batched mapper that tracks only
+    *arg_name*; the results are available as :attr:`access_range` and
+    :attr:`bad_subscripts`.
+    """
+
+    def __init__(self, kernel, arg_name):
+        self.arg_name = arg_name
+        self.inner_mapper = BatchedAccessRangeMapper(kernel, [arg_name])
+
+    def __call__(self, expr, inames):
+        return self.inner_mapper(expr, inames)
+
+    @property
+    def access_range(self):
+        """Access range recorded for *arg_name* by the inner mapper."""
+        return self.inner_mapper.access_ranges[self.arg_name]
+
+    @property
+    def bad_subscripts(self):
+        """Bad subscripts recorded for *arg_name* by the inner mapper."""
+        return self.inner_mapper.bad_subscripts[self.arg_name]
+
 # }}}