diff --git a/test/test_fusion.py b/test/test_fusion.py
index 98dbe241ec110915ae1d85a2326573f0f959f008..8e28fb3493e9517236d386989a51f3d9dfe440ef 100644
--- a/test/test_fusion.py
+++ b/test/test_fusion.py
@@ -52,3 +52,124 @@ def test_two_kernel_fusion(ctx_factory):
     knl = lp.fuse_kernels([knla, knlb], data_flow=[("out", 0, 1)])
     evt, (out,) = knl(queue)
     np.testing.assert_allclose(out.get(), np.arange(100, 110))
+
+
+def test_write_block_matrix_fusion(ctx_factory):
+    """
+    A slightly more complicated fusion test, where all
+    sub-kernels write into the same global matrix, but
+    in well-defined separate blocks. This tests makes sure
+    data flow specification is preserved during fusion for
+    matrix-assembly-like programs.
+    """
+
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    def init_global_mat_prg():
+        return lp.make_kernel(
+            [
+                "{[idof]: 0 <= idof < n}",
+                "{[jdof]: 0 <= jdof < m}"
+            ],
+            """
+                result[idof, jdof]  = 0 {id=init}
+            """,
+            [
+                lp.GlobalArg("result", None,
+                    shape="n, m",
+                    offset=lp.auto),
+                lp.ValueArg("n, m", np.int32),
+                "...",
+            ],
+            options=lp.Options(
+                return_dict=True
+            ),
+            default_offset=lp.auto,
+            name="init_a_global_matrix",
+        )
+
+    def write_into_mat_prg():
+        return lp.make_kernel(
+            [
+                "{[idof]: 0 <= idof < ndofs}",
+                "{[jdof]: 0 <= jdof < mdofs}"
+            ],
+            """
+                result[offset_i + idof, offset_j + jdof] = mat[idof, jdof]
+            """,
+            [
+                lp.GlobalArg("result", None,
+                    shape="n, m",
+                    offset=lp.auto),
+                lp.ValueArg("n, m", np.int32),
+                lp.GlobalArg("mat", None,
+                    shape="ndofs, mdofs",
+                    offset=lp.auto),
+                lp.ValueArg("offset_i", np.int32),
+                lp.ValueArg("offset_j", np.int32),
+                "...",
+            ],
+            options=lp.Options(
+                return_dict=True
+            ),
+            default_offset=lp.auto,
+            name="write_into_global_matrix",
+        )
+
+    # Construct a 2x2 diagonal matrix with
+    # random 5x5 blocks on the block-diagonal,
+    # and zeros elsewhere
+    n = 10
+    block_n = 5
+    mat1 = np.random.randn(block_n, block_n)
+    mat2 = np.random.randn(block_n, block_n)
+    answer = np.block([[mat1, np.zeros((block_n, block_n))],
+                      [np.zeros((block_n, block_n)), mat2]])
+    kwargs = {"n": n, "m": n}
+
+    # Do some renaming of individual programs before fusion
+    kernels = [init_global_mat_prg()]
+    for idx, (offset, mat) in enumerate([(0, mat1), (block_n, mat2)]):
+        knl = lp.rename_argument(write_into_mat_prg(), "mat", f"mat_{idx}")
+        kwargs[f"mat_{idx}"] = mat
+
+        for iname in knl.all_inames():
+            knl = lp.rename_iname(knl, iname, f"{iname}_{idx}")
+
+        knl = lp.rename_argument(knl, "ndofs", f"ndofs_{idx}")
+        knl = lp.rename_argument(knl, "mdofs", f"mdofs_{idx}")
+        kwargs[f"ndofs_{idx}"] = block_n
+        kwargs[f"mdofs_{idx}"] = block_n
+
+        knl = lp.rename_argument(knl, "offset_i", f"offset_i_{idx}")
+        knl = lp.rename_argument(knl, "offset_j", f"offset_j_{idx}")
+        kwargs[f"offset_i_{idx}"] = offset
+        kwargs[f"offset_j_{idx}"] = offset
+
+        kernels.append(knl)
+
+    fused_knl = lp.fuse_kernels(
+        kernels,
+        data_flow=[("result", 0, 1), ("result", 1, 2)],
+    )
+    fused_knl = lp.add_nosync(
+        fused_knl,
+        "global",
+        "writes:result",
+        "writes:result",
+        bidirectional=True,
+        force=True
+    )
+    evt, result = fused_knl(queue, **kwargs)
+    result = result["result"]
+    np.testing.assert_allclose(result, answer)
+
+
+if __name__ == "__main__":
+    import sys
+    if len(sys.argv) > 1:
+        exec(sys.argv[1])
+    else:
+        from pytest import main
+        main([__file__])