// LCA + DisjointSet
// 一個樹上路徑們並查的故事。(A story of merging paths on a tree with a disjoint-set union.)
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <cmath>
using namespace std;
typedef vector<int> VI;
typedef vector<VI> VVI;
// Range-minimum structure over a fixed int array: a bottom-up ("zkw")
// segment tree whose nodes store the INDEX of the smallest value below them.
// Storage lives in std::vector, so the object releases its memory and copies
// deeply (raw new[] buffers with no destructor leaked and copied shallowly).
class bst{
public:
    // Copies _val[0 .. n-1]. Data position i is stored at val[i + 1]; val[0]
    // and every slot after the data are padding, so the open-interval
    // boundaries used by query() (one slot left of a, one right of b) exist.
    bst(int n, int* _val){
        // Smallest power of two >= n + 3, computed with integers instead of
        // (1 << ceil(log2(n + 3))) to avoid floating-point rounding.
        int leaves = 1;
        while(leaves < n + 3) leaves <<= 1;
        base = leaves - 1;
        tree.assign(base*2+2, 0);
        val.assign(base+2, 10000000);
        // Leaf (base + 1 + i) refers to val[i].
        for(int lx = 0;lx < base+1;lx++) tree[base+lx+1] = lx;
        for(int lx = 0;lx < n;lx++) val[lx+1] = _val[lx];
        // Each internal node keeps the child index with the smaller value
        // (the left child on ties).
        for(int lx = base;lx;lx--){
            int l = tree[lx*2], r = tree[lx*2+1];
            tree[lx] = val[l] > val[r] ? r : l;
        }
    }
    // Returns an index i in [a, b] (0-based, requires 0 <= a <= b < n) such
    // that _val[i] is minimal over that range.
    int query(int a, int b) const{
        a++, b++;                       // shift to val[] indexing
        int ret = tree[base+a+1];       // start with position a itself
        // Walk up from the leaves just outside [a, b]; every sibling that lies
        // strictly inside the open interval is a candidate subtree.
        for(a+=base, b+=base+2;a^b^1;a>>=1, b>>=1){
            if(~a&1) if(val[ret] > val[tree[a^1]]) ret = tree[a^1];
            if(b&1) if(val[ret] > val[tree[b^1]]) ret = tree[b^1];
        }
        return ret-1;
    }
private:
    int base;                 // number of internal nodes; leaves start at base + 1
    std::vector<int> tree;    // tree[node] = val index of that subtree's minimum
    std::vector<int> val;     // padded copy of the input, shifted by one slot
};
// ---- Tree / LCA state, filled by dfs() and build() ----
VVI graph;            // adjacency lists, 0-based vertices
int ptr;              // next free slot in the Euler tour arrays
int hei[1000000];     // hei[v]: depth of v, the root (vertex 0) gets 1
// Euler tour: dfs() appends one entry when entering a vertex, one after each
// child returns and one more when leaving it, i.e. 3n - 1 entries for n
// vertices. With up to 1000000 vertices the arrays need 3000000 slots;
// 2000000 overflowed for n above ~666667.
int arr[3000000];     // arr[i]: vertex at tour position i
int harr[3000000];    // harr[i]: hei[arr[i]]
int ac1[1000000];     // ac1[v]: parent of v (the root is its own parent)
int ato[1000000][2];  // ato[v][0] / ato[v][1]: first / last tour position of v
bst* lca_bst;         // range-min structure over harr, created in build()
// Depth-first traversal from `nd`, whose parent is taken to be `fa`.
// For every vertex v reached it sets ac1[v] (parent) and hei[v] = hei[parent]+1.
// It also sets ato[v][0] / ato[v][1] (first / last tour position).
// It appends v to the Euler tour (arr, harr, advancing ptr) on entering v,
// after each child of v returns, and once more when leaving v.
// An explicit stack replaces recursion, so a path-shaped tree with ~10^6
// vertices does not overflow the call stack. The visiting order and every
// value written match the recursive formulation.
// Assumes `graph` is a tree (no cycles) -- TODO confirm against the input.
void dfs(int fa, int nd){
    std::vector<int> node_stk;   // vertices on the current root-to-leaf path
    std::vector<int> next_stk;   // next adjacency index to try for each of them
    ac1[nd] = fa, hei[nd] = hei[fa]+1;
    ato[nd][0] = ptr; arr[ptr] = nd; harr[ptr++] = hei[nd];
    node_stk.push_back(nd); next_stk.push_back(0);
    while(!node_stk.empty()){
        int u = node_stk.back();
        int i = next_stk.back();
        int deg = (int)graph[u].size();
        // Skip edges back to the parent; ac1[u] is the `fa` u was entered with.
        while(i < deg && graph[u][i] == ac1[u]) i++;
        if(i < deg){
            // Descend into the next child v.
            int v = graph[u][i];
            next_stk.back() = i + 1;
            ac1[v] = u, hei[v] = hei[u]+1;
            ato[v][0] = ptr; arr[ptr] = v; harr[ptr++] = hei[v];
            node_stk.push_back(v); next_stk.push_back(0);
        }else{
            // u is finished: record its exit, then the parent's after-child entry.
            ato[u][1] = ptr; arr[ptr] = u; harr[ptr++] = hei[u];
            node_stk.pop_back(); next_stk.pop_back();
            if(!node_stk.empty()){
                int p = node_stk.back();
                ato[p][1] = ptr; arr[ptr] = p; harr[ptr++] = hei[p];
            }
        }
    }
}
// Prepares LCA queries for the tree in `graph`, rooted at vertex 0.
// It runs the Euler-tour DFS, then builds the range-min structure over the
// recorded heights. `n` is not used: the tour length comes from ptr.
void build(int n){
    (void)n;
    hei[0] = 0;   // the root is its own parent, so dfs gives it height 1
    ptr = 0;
    dfs(0, 0);
    lca_bst = new bst(ptr, harr);
}
// Lowest common ancestor of a and b. It takes the Euler-tour window from the
// earlier first visit to the later last visit of the two vertices, and
// returns the vertex with the smallest height inside that window.
int lca(int a, int b){
    const int from = ato[a][0] < ato[b][0] ? ato[a][0] : ato[b][0];
    const int to   = ato[a][1] > ato[b][1] ? ato[a][1] : ato[b][1];
    const int pos  = lca_bst->query(from, to);
    return arr[pos];
}
// One half of a query path: the upward chain from endpoint `y` to the path's
// LCA `x`, keyed by h = hei[x] (as filled in by main()).
struct path{
    int h;  // height of the LCA; the sort key
    int x;  // the LCA vertex
    int y;  // the endpoint to climb from
}paths[2000000];
// Orders half-paths by LCA height, smallest (closest to the root) first.
bool operator<(path a, path b){return b.h > a.h;}
// Disjoint-set forest over vertices: fat[v] is v's parent in the forest, and
// a representative points to itself (main() initialises fat[v] = v).
int fat[1000000];
// Returns the representative of a's set and points every vertex on the way
// directly at it (full path compression). This version is iterative. The
// recursive form could recurse ~10^6 deep, because join() links roots without
// balancing and the climbing loop in main() can build parent chains as long
// as a tree path.
int qfat(int a){
    int root = a;
    while(fat[root] != root) root = fat[root];
    while(fat[a] != root){
        int next = fat[a];
        fat[a] = root;
        a = next;
    }
    return root;
}
// Merges the sets of a and b; b's representative becomes the new root.
void join(int a, int b){fat[qfat(a)] = qfat(b); return;}
int main(){
int n, m, k, q;
scanf("%d %d %d %d", &n, &m, &k, &q);
graph = VVI(n, VI());
for(int lx = 0;lx < m;lx++){
int a, b; scanf("%d %d", &a, &b); a--, b--;
graph[a].push_back(b);
graph[b].push_back(a);
//printf("%d <-> %d\n", a+1, b+1);
}
build(n);
for(int lx = 0;lx < k;lx++){
int a, b; scanf("%d %d", &a, &b); a--, b--;
int c = lca(a, b);
//printf("lca(%d %d) = %d\n", a+1, b+1, c+1);
paths[lx*2].h = paths[lx*2+1].h = hei[c];
paths[lx*2].x = paths[lx*2+1].x = c;
paths[lx*2].y = a, paths[lx*2+1].y = b;
}
sort(paths, paths+2*k);
for(int lx = 0;lx < n;lx++) fat[lx] = lx;
for(int lx = 0;lx < 2*k;lx++){
int y = paths[lx].y, x = paths[lx].x;
while(qfat(y) != qfat(x)){
join(x, y);
y = ac1[y];
}
}
while(q--){
int a, b; scanf("%d %d", &a, &b); a--, b--;
printf("%d\n", qfat(a) == qfat(b));
}
return 0;
}
// 沒有留言: (blog footer: "No comments")
// 張貼留言 (blog footer: "Post a comment")