diff --git a/loopy/buffer_writes.py b/loopy/buffer_writes.py
index 7083e3c269aea7d663c1c38605ecd8b581324b6d..7616dcba59d57997786538d9a67525bdfe07cac1 100644
--- a/loopy/buffer_writes.py
+++ b/loopy/buffer_writes.py
@@ -26,7 +26,8 @@ THE SOFTWARE.
 
 from loopy.array_buffer import (ArrayToBufferMap, NoOpArrayToBufferMap,
         AccessDescriptor)
-from loopy.symbolic import ExpandingIdentityMapper, SubstitutionMapper
+from loopy.symbolic import (get_dependencies, ExpandingIdentityMapper,
+        SubstitutionMapper)
 from pymbolic.mapper.substitutor import make_subst_func
 
 from pymbolic import var
@@ -112,8 +113,7 @@ class ArrayAccessReplacer(ExpandingIdentityMapper):
 
 def buffer_write(kernel, var_name, buffer_inames, init_expression=None,
         store_expression=None, within=None, default_tag="l.auto",
-        temporary_is_local=None, fetch_bounding_box=False,
-        within_inames=()):
+        temporary_is_local=None, fetch_bounding_box=False):
     """
     :arg init_expression: Either *None* (indicating the prior value of the buffered
         array should be read) or an expression optionally involving the
@@ -134,10 +134,8 @@ def buffer_write(kernel, var_name, buffer_inames, init_expression=None,
         store_expression = parse(store_expression)
 
     if isinstance(buffer_inames, str):
-        buffer_inames = buffer_inames.split(",")
-
-    if isinstance(within_inames, str):
-        within_inames = within_inames.split(",")
+        buffer_inames = [s.strip()
+                for s in buffer_inames.split(",") if s.strip()]
 
     for iname in buffer_inames:
         if iname not in kernel.all_inames():
@@ -170,6 +168,7 @@ def buffer_write(kernel, var_name, buffer_inames, init_expression=None,
     # }}}
 
     var_name_gen = kernel.get_var_name_generator()
+    within_inames = set()
 
     access_descriptors = []
     for insn in kernel.instructions:
@@ -178,6 +177,9 @@ def buffer_write(kernel, var_name, buffer_inames, init_expression=None,
 
         for assignee, index in insn.assignees_and_indices():
             if assignee == var_name:
+                within_inames.update(
+                        (get_dependencies(index) & kernel.all_inames())
+                        - buffer_inames_set)
                 access_descriptors.append(
                         AccessDescriptor(
                             identifier=insn.id,
@@ -211,7 +213,7 @@ def buffer_write(kernel, var_name, buffer_inames, init_expression=None,
         # {{{ find domain to be changed
 
         from loopy.kernel.tools import DomainChanger
-        domch = DomainChanger(kernel, buffer_inames_set)
+        domch = DomainChanger(kernel, buffer_inames_set | within_inames)
 
         if domch.leaf_domain_index is not None:
             # If the sweep inames are at home in parent domains, then we'll add
diff --git a/test/test_fortran.py b/test/test_fortran.py
index c68e963a60585634437d8a80b5eb1cbb1e37e2c2..eeac0b23abe75dffc27375bfc4b66b9708c7e63c 100644
--- a/test/test_fortran.py
+++ b/test/test_fortran.py
@@ -273,7 +273,12 @@ def test_tagged(ctx_factory):
     assert sum(1 for insn in lp.find_instructions(knl, "*$input")) == 2
 
 
-def test_matmul(ctx_factory):
+@pytest.mark.parametrize("buffer_inames", [
+    "",
+    "i_inner",
+    "i_inner,j_inner",
+    ])
+def test_matmul(ctx_factory, buffer_inames):
     fortran_src = """
         subroutine dgemm(m,n,l,a,b,c)
           implicit none
@@ -311,11 +316,8 @@ def test_matmul(ctx_factory):
     knl = lp.precompute(knl, "a_acc", "k_inner,i_inner")
     knl = lp.precompute(knl, "b_acc", "j_inner,k_inner")
 
-    # FIXME: also test
-    # knl = lp.buffer_write(knl, "c", (), init_expression="0",
-    #         store_expression="base+buffer")
-    knl = lp.buffer_write(knl, "c", "i_inner,j_inner", init_expression="0",
-            store_expression="base+buffer", within_inames="i_outer,j_outer")
+    knl = lp.buffer_write(knl, "c", buffer_inames=buffer_inames,
+            init_expression="0", store_expression="base+buffer")
 
     #ctx = ctx_factory()
     #lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5, m=7, l=10))