本课时有配套视频讲解,购买后即可观看(永久有效)
Prim堆优化求最小生成树
一、课上练习
编程练习
(暂无)
二、知识总结
✨ 核心思想
朴素 Prim 算法每轮需要 O(V) 时间找到 dist 最小的顶点,总时间 O(V²)。通过使用优先队列(最小堆)维护未选顶点的 dist 值,可以将"找最小"的操作优化到 O(log V),从而将总时间复杂度降低到 O(E log V)。这使得 Prim 算法在稀疏图上也能高效工作。
✨ 算法原理
优化思路:
朴素 Prim 的瓶颈在于每轮从 V-S 中找 dist 最小的顶点(O(V))。使用最小堆后:
- 取出 dist 最小的顶点:O(log V)
- 更新邻居的 dist 值并入堆:O(log V) per edge
- 总复杂度:O(E log V)
与 Dijkstra 堆优化的对比:
两者的代码结构几乎完全相同,唯一的区别是松弛条件:
- Dijkstra:
dist[v] = dist[u] + w(u,v)(距离累加) - Prim:
dist[v] = w(u,v)(仅取边权)
实现细节:
使用 priority_queue 存储 (dist[v], v) 对。由于 C++ 的 priority_queue 是最大堆,可以:
- 存负数:
(-dist[v], v) - 使用
greater<>变成最小堆 - 使用
pair自动按第一个元素排序
取出堆顶时,如果该顶点已经在 MST 中,直接跳过(懒删除策略)。
✨ 代码实现
1#include <iostream>
2#include <vector>
3#include <queue>
4#include <cstring>
5using namespace std;
6
7const int MAXN = 200005;
8const int INF = 0x3f3f3f3f;
9
10struct Edge {
11 int to, weight;
12};
13
14vector<Edge> adj[MAXN]; // 邻接表
15int dist[MAXN]; // 到已选集合的最小边权
16bool inMST[MAXN]; // 是否已加入MST
17int n, m;
18
19// 优先队列中的元素:(距离, 节点)
20typedef pair<int, int> PII;
21
22// Prim堆优化,返回MST总权值
23long long primHeap() {
24 // 最小堆
25 priority_queue<PII, vector<PII>, greater<PII>> pq;
26
27 memset(dist, 0x3f, sizeof(dist));
28 memset(inMST, false, sizeof(inMST));
29
30 dist[1] = 0;
31 pq.push({0, 1}); // (距离, 节点)
32 long long totalWeight = 0;
33 int cnt = 0; // 已加入MST的节点数
34
35 while (!pq.empty() && cnt < n) {
36 auto [d, u] = pq.top();
37 pq.pop();
38
39 // 如果u已在MST中,跳过(懒删除)
40 if (inMST[u]) continue;
41
42 // 将u加入MST
43 inMST[u] = true;
44 totalWeight += d;
45 cnt++;
46
47 // 遍历u的所有邻边
48 for (auto& e : adj[u]) {
49 int v = e.to, w = e.weight;
50 // 如果v未在MST中,且通过u能获得更小的连接边权
51 if (!inMST[v] && w < dist[v]) {
52 dist[v] = w;
53 pq.push({w, v}); // 将新的距离入堆
54 }
55 }
56 }
57
58 // 检查是否所有节点都加入了MST
59 if (cnt < n) return -1; // 图不连通
60 return totalWeight;
61}
62
63int main() {
64 ios::sync_with_stdio(false);
65 cin.tie(nullptr);
66
67 cin >> n >> m;
68 for (int i = 0; i < m; i++) {
69 int u, v, w;
70 cin >> u >> v >> w;
71 adj[u].push_back({v, w});
72 adj[v].push_back({u, w});
73 }
74
75 long long ans = primHeap();
76 if (ans == -1) {
77 cout << "orz" << endl; // 图不连通
78 } else {
79 cout << ans << endl;
80 }
81
82 return 0;
83}1#include <iostream>
2#include <vector>
3#include <queue>
4#include <cstring>
5using namespace std;
6
7const int MAXN = 200005;
8const int INF = 0x3f3f3f3f;
9
10vector<pair<int,int>> adj[MAXN]; // (终点, 边权)
11int dist[MAXN];
12bool inMST[MAXN];
13int from[MAXN]; // 记录v是从哪个节点加入的
14int n, m;
15
16typedef pair<int,int> PII;
17
18struct MSTEdge {
19 int u, v, w;
20};
21
22pair<long long, vector<MSTEdge>> primWithEdges() {
23 priority_queue<PII, vector<PII>, greater<PII>> pq;
24 memset(dist, 0x3f, sizeof(dist));
25 memset(inMST, false, sizeof(inMST));
26 memset(from, -1, sizeof(from));
27
28 dist[1] = 0;
29 pq.push({0, 1});
30 long long total = 0;
31 int cnt = 0;
32 vector<MSTEdge> edges;
33
34 while (!pq.empty() && cnt < n) {
35 auto [d, u] = pq.top();
36 pq.pop();
37
38 if (inMST[u]) continue;
39
40 inMST[u] = true;
41 total += d;
42 cnt++;
43
44 // 记录边
45 if (from[u] != -1) {
46 edges.push_back({from[u], u, d});
47 }
48
49 for (auto& [v, w] : adj[u]) {
50 if (!inMST[v] && w < dist[v]) {
51 dist[v] = w;
52 from[v] = u; // v从u连过来
53 pq.push({w, v});
54 }
55 }
56 }
57
58 if (cnt < n) return {-1, {}};
59 return {total, edges};
60}
61
62int main() {
63 ios::sync_with_stdio(false);
64 cin.tie(nullptr);
65
66 cin >> n >> m;
67 for (int i = 0; i < m; i++) {
68 int u, v, w;
69 cin >> u >> v >> w;
70 adj[u].push_back({v, w});
71 adj[v].push_back({u, w});
72 }
73
74 auto [weight, edges] = primWithEdges();
75 if (weight == -1) {
76 cout << "图不连通" << endl;
77 } else {
78 cout << "MST总权值:" << weight << endl;
79 for (auto& e : edges) {
80 cout << e.u << " -- " << e.v
81 << " (权值 " << e.w << ")" << endl;
82 }
83 }
84
85 return 0;
86}1#include <iostream>
2#include <vector>
3#include <queue>
4#include <cstring>
5using namespace std;
6
7const int MAXN = 200005;
8const int INF = 0x3f3f3f3f;
9
10vector<pair<int,int>> adj[MAXN];
11int n, m;
12
13// 朴素版 O(V²) - 适合稠密图
14long long primNaive() {
15 vector<int> dist(n + 1, INF);
16 vector<bool> vis(n + 1, false);
17 dist[1] = 0;
18 long long total = 0;
19
20 for (int i = 0; i < n; i++) {
21 int u = -1;
22 for (int j = 1; j <= n; j++)
23 if (!vis[j] && (u == -1 || dist[j] < dist[u]))
24 u = j;
25 if (dist[u] == INF) return -1;
26 vis[u] = true;
27 total += dist[u];
28 for (auto& [v, w] : adj[u])
29 if (!vis[v] && w < dist[v])
30 dist[v] = w;
31 }
32 return total;
33}
34
35// 堆优化版 O(E log V) - 适合稀疏图
36long long primHeap() {
37 vector<int> dist(n + 1, INF);
38 vector<bool> vis(n + 1, false);
39 priority_queue<pair<int,int>, vector<pair<int,int>>,
40 greater<pair<int,int>>> pq;
41
42 dist[1] = 0;
43 pq.push({0, 1});
44 long long total = 0;
45 int cnt = 0;
46
47 while (!pq.empty() && cnt < n) {
48 auto [d, u] = pq.top();
49 pq.pop();
50 if (vis[u]) continue;
51 vis[u] = true;
52 total += d;
53 cnt++;
54 for (auto& [v, w] : adj[u])
55 if (!vis[v] && w < dist[v]) {
56 dist[v] = w;
57 pq.push({w, v});
58 }
59 }
60 return cnt < n ? -1 : total;
61}
62
63int main() {
64 ios::sync_with_stdio(false);
65 cin.tie(nullptr);
66
67 cin >> n >> m;
68 for (int i = 0; i < m; i++) {
69 int u, v, w;
70 cin >> u >> v >> w;
71 adj[u].push_back({v, w});
72 adj[v].push_back({u, w});
73 }
74
75 cout << primHeap() << endl;
76 return 0;
77}✨ 执行示例
以下面的图为例(6 个顶点,8 条边):
1 --1-- 2 --6-- 3
| /| |
4 3 5 2
| / | |
4 --2-- 5 --8-- 6边:(1,2,1), (1,4,4), (2,4,3), (2,5,5), (2,3,6), (3,6,2), (4,5,2), (5,6,8)
堆优化 Prim 执行过程:
1初始: pq = {(0,1)}
2
3取出(0,1): 加入1, total=0, cnt=1
4 更新: dist[2]=1→入堆, dist[4]=4→入堆
5 pq = {(1,2), (4,4)}
6
7取出(1,2): 加入2, total=1, cnt=2
8 更新: dist[3]=6→入堆, dist[4]=min(4,3)=3→入堆, dist[5]=5→入堆
9 pq = {(3,4), (4,4), (5,5), (6,3)}
10
11取出(3,4): 加入4, total=4, cnt=3
12 更新: dist[5]=min(5,2)=2→入堆
13 pq = {(2,5), (4,4), (5,5), (6,3)}
14
15取出(2,5): 加入5, total=6, cnt=4
16 更新: dist[6]=min(INF,8)=8→入堆
17 pq = {(4,4), (5,5), (6,3), (8,6)}
18
19取出(4,4): 节点4已在MST,跳过(懒删除)
20取出(5,5): 节点5已在MST,跳过(懒删除)
21
22取出(6,3): 加入3, total=12, cnt=5
23 更新: dist[6]=min(8,2)=2→入堆
24 pq = {(2,6), (8,6)}
25
26取出(2,6): 加入6, total=14, cnt=6
27 全部加入完毕
28
29MST总权值 = 14
30MST的边: (1,2,1), (2,4,3), (4,5,2), (2,3,6), (3,6,2)注意懒删除: 节点 4 对应了两个堆元素 (3,4) 和 (4,4)。当 (3,4) 被取出时,4 加入 MST;之后 (4,4) 被取出时,发现 4 已在 MST 中,直接跳过。
✨ 解题步骤详解
- 建图:使用邻接表存储(堆优化必须用邻接表)
- 初始化:dist 数组、inMST 数组、优先队列
- 起始点入堆:
pq.push({0, start}) - 主循环:
- 取堆顶 (d, u)
- 如果 u 已在 MST 中,跳过
- 否则加入 MST,累加权值
- 遍历 u 的邻边,更新 dist 并入堆
- 判断连通性:检查加入 MST 的节点数是否为 n
选择朴素版还是堆优化版?
- V 较小(
<5000),E 接近 V²:朴素版 O(V²) 更快 - V 较大,E 远小于 V²:堆优化版 O(E log V) 更快
- 竞赛中一般用堆优化版或 Kruskal
✨ 常见错误
- 忘记懒删除:取出堆顶时必须检查该节点是否已在 MST 中
- 堆的排序方向错误:C++ priority_queue 默认是最大堆,需要用
greater<>或存负数 - dist 更新条件写错:应该是
w < dist[v](边权),不是dist[u] + w < dist[v](那是 Dijkstra) - 堆中存了过多冗余元素:虽然不影响正确性,但可能导致内存过大
- 邻接表存储有向边:无向图的每条边需要存两次
- 起始点的处理:起始点 dist=0,入堆后被取出时 total+=0,不影响结果
✨ 算法评价
| 版本 | 时间复杂度 | 空间复杂度 | 适用场景 |
|---|---|---|---|
| 朴素 Prim | O(V²) | O(V²) | 稠密图 |
| 堆优化 Prim | O(E log V) | O(V + E) | 稀疏图 |
| Kruskal | O(E log E) | O(V + E) | 稀疏图 |
三种 MST 算法的选择:
| 条件 | 推荐算法 |
|---|---|
| 稠密图(E 接近 V²) | 朴素 Prim O(V²) |
| 稀疏图 | Kruskal O(E log E) 或堆优化 Prim |
| 需要记录 MST 的边 | 两者都可以 |
| 只需要 MST 总权值 | 两者都可以 |
| 边已排序 | Kruskal(跳过排序步骤) |
堆优化 Prim 和 Kruskal 在稀疏图上的时间复杂度接近(O(E log V) vs O(E log E)),实际表现差不多。竞赛中 Kruskal 因为实现更简洁而使用更广泛,但堆优化 Prim 也是重要的工具。