hihoCoder 1093-最短路径·三:SPFA算法 #1093 : 最短路径·三:SPFA算法 时间限制:10000ms 单点时限:1000ms 内存限制:256MB 描述 万圣节的晚上,小Hi和小Ho在吃过晚饭之后,来到了一个巨大的鬼屋! 鬼屋中一共有N个地点,分别编号为1..N,这N个地点之间互相有一些道路连通,两个地点之间可能有多条道路连通,但是并不存在一条两端都是同一个地点的道路。 不过这个鬼屋虽然很大,但是其中的道路并不算多,所以小Hi还是希望能够知道从入口到出口的最短距离是多少? 提示:Super Programming Festival Algorithm。 输入 每个测试点(输入文件)有且仅有一组测试数据。 在一组测试数据中: 第1行为4个整数N、M、S、T,分别表示鬼屋中地点的个数和道路的条数,入口(也是一个地点)的编号,出口(同样也是一个地点)的编号。 接下来的M行,每行描述一条道路:其中的第i行为三个整数u_i, v_i, length_i,表明在编号为u_i的地点和编号为v_i的地点之间有一条长度为length_i的道路。 对于100%的数据,满足N<=10^5,M<=10^6, 1 <= length_i <= 10^3, 1 <= S, T <= N, 且S不等于T。 对于100%的数据,满足小Hi和小Ho总是有办法从入口通过地图上标注出来的道路到达出口。 输出 对于每组测试数据,输出一个整数Ans,表示那么小Hi和小Ho为了走出鬼屋至少要走的路程。 样例输入 5 10 3 5 1 2 997 2 3 505 3 4 118 4 5 54 3 5 480 3 4 796 5 2 794 2 5 146 5 4 604 2 5 63 样例输出 172
对于边稀疏的图,SPFA算法可以更高效的求解单源最短路径问题。 假设要求S和T的最短路径,该算法使用一个队列,最开始队列中只有(S,0)—表示当前处于点S,从点S到达该点的距离为0,然后每次从队首取出一个节点(i, L)——表示当前处于点i,从点S到达该点的距离为L,接下来遍历所有从这个节点出发的边(i, j, l)——表示i和j之间有一条长度为l的边,在将(j, L+l)加入到队列之前,先判断队列中是否存在点j,如果存在(j,l’),则比较L+l和l’的大小关系,如果L+l>=l’,那么(j,L+l)这条路就没必要继续搜索下去了,所以不将(j,L+l)加入队列;如果L+l<l’,那么原来的(j,l’)没必要继续搜索下去了,把(j,l’)替换成(j,L+l)即可。如果原队列中不存在j点,则直接把(j,L+l)加入队列。 当队列为空时,(T,ans)就是S到T的距离为ans。所以SPFA在某种程度上来说,就是BFS+剪枝。 SPFA的原理不难,写代码的时候要注意几点:1)N的范围是N<=10^5,如果用数组存储的话,内存够呛,所以建议用vector动态分配;2)每次从队列中弹出一个点时,记得将其在队列中的标记清空,即used[i]=0;3)处理好数据结构问题,因为数据输入中2点之间的距离可能有多个值,你可以在输入的时候只存储最小值,你也可以为了方便先把所有数据都存起来,统一在SPFA算法中进行大小判断,本代码使用第二种策略。
完整代码如下:
#include <iostream>
#include <cstdio>
#include <vector>
#include <queue>
using namespace std;
const int MAX_N = 1e5 + 10;
const int INF = 1e6 + 10;
int n, m, s, t;
//int path[MAX_N][MAX_N];//数组太大,改用vector动态分配
vector<int> path[MAX_N]; //path[i][j]表示第i个点和第path[i][j]个点有路相同,其中j是输入的顺序
vector<int> dist[MAX_N]; //dist[i][j]表示和第i个点相连的第j个点的距离是dist[i][j],path[i][j]和dist[i][j]是通过相同的j来共享数据
int used[MAX_N]; //used[i]表示第i个点是否在队列中
int s_dist[MAX_N]; //s_dist[i]表示第i个点和起始点s的距离
int spfa()
{
int q1 = s;
used[q1] = 1;
for (int i = 1; i <= n; i++)
s_dist[i] = INF;
s_dist[s] = 0;
queue<int> Q;
Q.push(q1);
while (!Q.empty()) {
q1 = Q.front();
Q.pop();
used[q1] = 0; //当弹出队列时记得设置used[i]=0
int qs = path[q1].size(); //和q1点相连的点的个数
for (int i = 0; i < qs; i++) {
int u = path[q1][i];
if (s_dist[q1] + dist[q1][i] < s_dist[u]) //如果u是q1的父节点,肯定不满足这一条,所以不会回溯
{
s_dist[u] = s_dist[q1] + dist[q1][i];
if (used[u] == 0) {
used[u] = 1;
Q.push(u);
}
}
}
}
return s_dist[t];
}
int main()
{
//freopen("input.txt","r",stdin);
scanf("%d%d%d%d", &n, &m, &s, &t);
int u, v, len;
while (m–) {
scanf("%d%d%d", &u, &v, &len); //记录了所有数据,此处还可以优化
path[u].push_back(v);
path[v].push_back(u);
dist[u].push_back(len);
dist[v].push_back(len);
}
printf("%d\n", spfa());
return 0;
}
本代码提交AC,用时208MS,内存19MB。