0%

浅谈树链剖分

目录

1.树链剖分简介

2.树链剖分的实现

3.应用树链剖分

4.一些问题的解答

1.树链剖分简介

重要!前置知识

1.dfs序

2.线段树

3.链式前向星

嗯,暂时就这几点


回归正题,树链剖分是什么?

树链剖分是将一棵树(不一定是二叉树)分成一条链,每一个点只属于一条链的算法,然后用数据结构(线段树,树状数组)进行维护,这里主要使用线段树维护。

树链剖分能做什么?

假设给你几个操作:

1.将$ u $节点到 $v$ 节点的最短路所经过的节点加 $k$

2.求 $u$ 节点到 $v$ 节点的最短路所经过的节点的和

3.将 $x$ 节点的所有子树加上 $k$

4.求 $x$ 节点的所有子树的和

怎么做?

线段树?不行1、2操作处理不了。

树状数组?更不行,比线段树还差。

那怎么办?

树链剖分闪亮登场!

2.树链剖分的实现

思想就不讲了,直接看实现就可以了。

我们先看看一些术语

1.重儿子:以一个非叶子节点为根,其儿子中子树最多的

2.轻儿子:以一个非叶子节点为根,不是重儿子的就是轻儿子

3.重边:一个非叶子节点连接其重儿子的边

4.轻边:不是重边就是轻边

5.重链:由重边连接起来的链就是重链

思考一个小问题,为什么都是非叶子节点?

当然是叶子节点没有儿子啦

树链剖分就可以给每个节点重新编号,让每条重链的编号都是连续的。

我们需要维护一下几个变量(有点多)

1
2
3
4
5
6
7
8
9
int son[maxn];//重儿子
int top[maxn];//链头
int id[maxn];//新的编号
int a[maxn];//新编号后的权值
int siz[maxn];//节点的大小
int tot;//编到了几号
int dep[maxn];//深度(为被第几个dfs到的)
int fa[maxn];//节点的父亲
//主要变量就是上面几个了

光看变量可能不理解,上图:

重儿子(红色的是其父亲的重儿子,蓝色的是重边):

重新编号(橙色的是新编号):

节点大小(绿色的是每个点的大小,去掉了编号):

本人画图技巧差,如有更好的欢迎提供!

由图片,可以看出:

siz=子树个数+1

son=siz最大的儿子

其他不多说。

来模拟一下标号的过程

开始在节点1,表上1号

接下来在节点1的重儿子节点4,表上2号

接下来在节点4的重儿子节点7,标上3号

接下来在节点7的重儿子节点8,标上4号

接下来回溯到节点7,再到给节点7的轻儿子9节点9,标上5号

以此类推

树链剖分的实现

树链剖分的核心代码是两个$dfs$:

$dfs1$维护的变量:son,siz,dep,fa

$dfs2$维护的变量:id,a,top

给出代码:

第一步

第一步就是$dfs1$,上面说过作用了,不多说。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
int dfs1(int now,int f,int d)//now是现在的节点,f是now的父亲节点,d是现在的深度
{
dep[now]=d;
fa[now]=f;
siz[now]=1;//大小初始为1,也就是本身
int ma=-1;//ma为现在儿子中最大的siz
for(int i=head[now];i;i=e[i].next)//链式前向星存边
{
if(e[i].v==f)//这样就是标记过了,因为是无向边
continue;
siz[now]+=dfs1(e[i].v,now,d+1);//自己的大小=自己所有儿子的大小
if(siz[e[i].v]>ma)//这里是找重儿子
ma=siz[e[i].v],son[now]=e[i].v;
}
return siz[now];
}

第二步

第二步就是$dfs2$了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void dfs2(int now,int t)//t为现在的链头
{
id[now]=++tot;//编号
a[tot]=b[now];
top[now]=t;//链头
if(!son[now])//这是个叶子节点,就不需要继续执行了
return;
dfs2(son[now],t);//优先标记重儿子
for(int i=head[now];i;i=e[i].next)//遍历每一个儿子
{
if(e[i].v==fa[now]||e[i].v==son[now])//这就代表处理过了,跳过
continue;
dfs2(e[i].v,e[i].v);//每个轻儿子,是单独的一条链
}
}

第三步

线段树,不多说

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
void build(int l,int r,int rt)//建树
{
if(l==r)
{
sum[rt]=a[l];
return;
}
int mid=l+r>>1;
build(l,mid,rt<<1);
build(mid+1,r,rt<<1|1);
sum[rt]=sum[rt<<1]+sum[rt<<1|1];

}
void pushdown(int rt,int l,int r)//懒标记
{
if(lazy[rt]>0)
{
int mid=l+r>>1;
lazy[rt<<1]=(lazy[rt<<1]+lazy[rt]);
lazy[rt<<1|1]=(lazy[rt<<1|1]+lazy[rt]);
sum[rt<<1]=(sum[rt<<1]+(mid-l+1)*lazy[rt]);
sum[rt<<1|1]=(sum[rt<<1|1]+(r-mid)*lazy[rt]);
lazy[rt]=0;
}
}
void update(int l,int r,int rt,int i,int j,int k)//区间修改
{
if(i<=l&&r<=j)
{
lazy[rt]+=k;
sum[rt]+=(r-l+1)*k;

return;
}
pushdown(rt,l,r);
int mid=l+r>>1;
if(i<=mid)
update(l,mid,rt<<1,i,j,k);
if(mid<j)
update(mid+1,r,rt<<1|1,i,j,k);
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
}
int query(int l,int r,int rt,int i,int j)//区间查询和
{
if(i<=l&&r<=j)
{
return sum[rt];
}
pushdown(rt,l,r);
int ans=0,mid=l+r>>1;
if(i<=mid)
ans+=query(l,mid,rt<<1,i,j);

if(mid<j)
ans+=query(mid+1,r,rt<<1|1,i,j);
return ans;
}

第四步

这里是求 $u$ 节点到 $v$ 节点的最短路所经过的节点的和

这是树链剖分的一个难点,我们慢慢解释:

首先,指定两个节点,让它们一直“跳”,直到在同一条链。

怎么“跳”?

别忘了$dfs2$,这保证了编号的连续。

我们可以让深度浅的不动,让深度深的往上跳,每一次呢,跳到链头的父亲,直到在同一条链上。

每次跳的时候,ans+一开始的编号到链头,用线段树求连续和就行啦。

跳完后,直接让ans+链头(编号)到编号的和即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
int treesum(int u,int v)
{
int s=0;
int tu=top[u];
int tv=top[v];
while(tu!=tv){
if(dep[tu]<dep[tv]){
swap(tu,tv);
swap(u,v);
}
s+=query(1,n,1,id[tu],id[u]);
u=fa[tu];
tu=top[u];
}
if(dep[u]>dep[v])swap(u,v);
s+=query(1,n,1,id[u],id[v]);
return s;
}

第五步

将 $x$ 节点的所有子树加上 $k$

和上一步一样,只是求和改成修改

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void change(int u,int v,int x)
{
int tu=top[u];
int tv=top[v];
while(tu!=tv){
if(dep[tu]<dep[tv]){
swap(tu,tv);
swap(u,v);
}
update(1,n,1,id[tu],id[u],x);
u=fa[tu];
tu=top[u];
}
if(dep[u]>dep[v])swap(u,v);
update(1,n,1,id[u],id[v],x);

}

完整代码

题目

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
#include<bits/stdc++.h>
#define inf 2147283647
#define int long long
using namespace std;
const int maxn=2000005;
int n,m,r,p,lazy[maxn<<2],cnt,dep[maxn],b[maxn],id[maxn],son[maxn],sum[maxn<<2],fa[maxn],siz[maxn],a[maxn],top[maxn];
struct node{
int u,v,next;
}e[maxn];
int head[maxn],tot=0;
void add(int u,int v)
{
e[++cnt].u=u;
e[cnt].v=v;
e[cnt].next=head[u];
head[u]=cnt;
}
int dfs1(int now,int f,int d)
{
dep[now]=d;
fa[now]=f;
siz[now]=1;
int ma=-1;
for(int i=head[now];i;i=e[i].next)
{
if(e[i].v==f)
continue;
siz[now]+=dfs1(e[i].v,now,d+1);
if(siz[e[i].v]>ma)
ma=siz[e[i].v],son[now]=e[i].v;
}
return siz[now];
}
void dfs2(int now,int t)
{
id[now]=++tot;
a[tot]=b[now];
top[now]=t;
if(!son[now])
return;
dfs2(son[now],t);
for(int i=head[now];i;i=e[i].next)
{
if(!id[e[i].v])
dfs2(e[i].v,e[i].v);
}
}
void pushdown(int rt,int l,int r)
{
if(lazy[rt]>0)
{
int mid=l+r>>1;
lazy[rt<<1]=(lazy[rt<<1]+lazy[rt])%p;
lazy[rt<<1|1]=(lazy[rt<<1|1]+lazy[rt])%p;
sum[rt<<1]=(sum[rt<<1]+(mid-l+1)*lazy[rt])%p;
sum[rt<<1|1]=(sum[rt<<1|1]+(r-mid)*lazy[rt])%p;
lazy[rt]=0;
}
}
void build(int l,int r,int rt)
{
if(l==r)
{
sum[rt]=a[l]%p;
return;
}
int mid=l+r>>1;
build(l,mid,rt<<1);
build(mid+1,r,rt<<1|1);
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
sum[rt]%=p;
}
void update(int l,int r,int rt,int i,int j,int k)
{
if(i<=l&&r<=j)
{
lazy[rt]+=k;
lazy[rt]%=p;
sum[rt]+=(r-l+1)*k;
sum[rt]%=p;

return;
}
pushdown(rt,l,r);
int mid=l+r>>1;
if(i<=mid)
update(l,mid,rt<<1,i,j,k);
if(mid<j)
update(mid+1,r,rt<<1|1,i,j,k);
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
sum[rt]%=p;
}
int query(int l,int r,int rt,int i,int j)
{
if(i<=l&&r<=j)
{

return sum[rt]%p;
}
pushdown(rt,l,r);
int ans=0,mid=l+r>>1;
if(i<=mid)
ans+=query(l,mid,rt<<1,i,j);
ans%=p;
if(mid<j)
ans+=query(mid+1,r,rt<<1|1,i,j);
return ans%p;
}
void change(int u,int v,int x)
{
int tu=top[u];
int tv=top[v];
while(tu!=tv){
if(dep[tu]<dep[tv]){
swap(tu,tv);
swap(u,v);
}
update(1,n,1,id[tu],id[u],x);
u=fa[tu];
tu=top[u];
}
if(dep[u]>dep[v])swap(u,v);
update(1,n,1,id[u],id[v],x);

}
int treesum(int u,int v)
{
int s=0;
int tu=top[u];
int tv=top[v];
while(tu!=tv){
if(dep[tu]<dep[tv]){
swap(tu,tv);
swap(u,v);
}
s+=query(1,n,1,id[tu],id[u]);
s%=p;
u=fa[tu];
tu=top[u];
}
if(dep[u]>dep[v])swap(u,v);
s+=query(1,n,1,id[u],id[v]);
return s%p;
}
signed main(){
cin>>n>>m>>r>>p;
for(int i=1;i<=n;i++)
cin>>b[i];
for(int i=1;i<=n-1;i++)
{
int u,v;
cin>>u>>v;
add(u,v);
add(v,u);
}
dfs1(r,0,1);
dfs2(r,r);
build(1,n,1);
for(int i=1;i<=m;i++)
{
int c,x,y,z;
cin>>c>>x;
if(c==1)
cin>>y>>z;
if(c==2||c==3)
cin>>y;
z%=p;
if(c==1)
change(x,y,z);
if(c==2)
cout<<treesum(x,y)<<endl;
if(c==3)
update(1,n,1,id[x],id[x]+siz[x]-1,y);//全部子树修改就是从它到它+它的大小-1,因为编号是连续的
if(c==4)//同上
cout<<query(1,n,1,id[x],id[x]+siz[x]-1)<<endl;
}
return 0;
}

大概就是这样了,时间复杂度为$O((log_2^n)^2)$,空间复杂度为$O(n)$。

3.应用树链剖分

题目

分析:

经典的树链剖分题,但题面有点不易理解,具体可以看成这样:

每只奶牛回家时,就将它的家那个节点+1,这样代表有牛了

放慢的次数,就是区间查询和,从1~家的区间查询

所以是树链剖分+区间查询+单点修改

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#include<bits/stdc++.h>
#define inf 2147283647
#define int long long
using namespace std;
const int maxn=500005;
int head[maxn],id[maxn],dep[maxn],fa[maxn],son[maxn],top[maxn],tot,cnt,n,m,siz[maxn],a[maxn],b[maxn],sum[maxn<<2],lazy[maxn<<2];
struct node{
int u,v,next;
}e[maxn];
void add(int u,int v)
{
e[++cnt].u=u;
e[cnt].v=v;
e[cnt].next=head[u];
head[u]=cnt;
}
int dfs1(int now,int f,int d)
{
dep[now]=d;
fa[now]=f;
siz[now]=1;
int ma=-1;
for(int i=head[now];i;i=e[i].next)
{
if(e[i].v==f)
continue;
siz[now]+=dfs1(e[i].v,now,d+1);
if(siz[e[i].v]>ma)
ma=siz[e[i].v],son[now]=e[i].v;
}
return siz[now];
}
void dfs2(int now,int t)
{
id[now]=++tot;
a[tot]=b[now];
top[now]=t;
if(!son[now])
return;
dfs2(son[now],t);
for(int i=head[now];i;i=e[i].next)
{
if(e[i].v==fa[now]||e[i].v==son[now])
continue;
dfs2(e[i].v,e[i].v);
}
}
void pushdown(int rt,int l,int r)
{
if(lazy[rt]>0)
{
int mid=l+r>>1;
lazy[rt<<1]=(lazy[rt<<1]+lazy[rt]);
lazy[rt<<1|1]=(lazy[rt<<1|1]+lazy[rt]);
sum[rt<<1]=(sum[rt<<1]+(mid-l+1)*lazy[rt]);
sum[rt<<1|1]=(sum[rt<<1|1]+(r-mid)*lazy[rt]);
lazy[rt]=0;
}
}
int query(int l,int r,int rt,int i,int j)
{
if(i<=l&&r<=j)
{
return sum[rt];
}
pushdown(rt,l,r);
int ans=0,mid=l+r>>1;
if(i<=mid)
ans+=query(l,mid,rt<<1,i,j);

if(mid<j)
ans+=query(mid+1,r,rt<<1|1,i,j);
return ans;
}
void build(int l,int r,int rt)
{
if(l==r)
{
sum[rt]=a[l];
return;
}
int mid=l+r>>1;
build(l,mid,rt<<1);
build(mid+1,r,rt<<1|1);
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
;
}
void update(int l,int r,int rt,int i,int j,int k)//这里用的区间修改(当然单点修改也行),懒标记优化
{
if(i<=l&&r<=j)
{
lazy[rt]+=k;
sum[rt]+=(r-l+1)*k;
return;
}
pushdown(rt,l,r);
int mid=l+r>>1;
if(i<=mid)
update(l,mid,rt<<1,i,j,k);
if(mid<j)
update(mid+1,r,rt<<1|1,i,j,k);
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
}
int treesum(int u,int v)
{
int s=0;
int tu=top[u];
int tv=top[v];
while(tu!=tv){
if(dep[tu]<dep[tv]){
swap(tu,tv);
swap(u,v);
}
s+=query(1,n,1,id[tu],id[u]);
u=fa[tu];
tu=top[u];
}
if(dep[u]>dep[v])swap(u,v);
s+=query(1,n,1,id[u],id[v]);
return s;
}
signed main(){
// freopen("P3178_1.in","r",stdin);
cin>>n;
// cout<<n<<endl;
for(int i=1;i<=n-1;i++)
{
int u,v;
cin>>u>>v;
add(u,v);
add(v,u);
}
dfs1(1,0,1);
dfs2(1,1);
build(1,n,1);
for(int i=1;i<=n;i++)
{
int t;
cin>>t;
cout<<treesum(1,t)<<endl;
update(1,n,1,id[t],id[t],1);
}
return 0;
}

题目

分析:

经典的树链剖分模板题,根据每个操作,套模板即可。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include<bits/stdc++.h>
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
#define int long long
using namespace std;
const int maxn=500005;
int a[maxn],b[maxn],son[maxn],siz[maxn],id[maxn],top[maxn],sum[maxn<<2];
int fa[maxn],cnt,tot,n,m,head[maxn],dep[maxn],lazy[maxn<<2];
struct node{
int u,v,next;
}e[maxn];
void add(int u,int v)
{
e[++cnt].u=u;
e[cnt].v=v;
e[cnt].next=head[u];
head[u]=cnt;
}
int dfs1(int now,int f,int d)
{
fa[now]=f;
dep[now]=d;
siz[now]=1;
int ma=-1;
for(int i=head[now];i;i=e[i].next)
{
if(e[i].v==f)
continue;
siz[now]+=dfs1(e[i].v,now,d+1);

if(siz[e[i].v]>ma)
ma=siz[e[i].v],son[now]=e[i].v;
}
return siz[now];
}
void dfs2(int now,int t)
{
id[now]=++tot;
top[now]=t;
a[tot]=b[now];
if(!son[now])
return;
dfs2(son[now],t);
for(int i=head[now];i;i=e[i].next)
if(!id[e[i].v])
dfs2(e[i].v,e[i].v);
}
void build(int l,int r,int rt)
{
if(l==r)
{
sum[rt]=a[l];
return;
}
int mid=l+r>>1;
build(lson);
build(rson);
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
}
void pushdown(int rt,int l,int r)
{
if(lazy[rt])
{
int mid=l+r>>1;
lazy[rt<<1]+=lazy[rt];
lazy[rt<<1|1]+=lazy[rt];
sum[rt<<1]+=(mid-l+1)*lazy[rt];
sum[rt<<1|1]+=(r-mid)*lazy[rt];
lazy[rt]=0;
}
}
void update(int l,int r,int rt,int i,int j,int k)
{
// cout<<l<<' '<<r<<" "<<i<<' '<<j<<endl;
if(i<=l&&r<=j)
{
sum[rt]+=(r-l+1)*k;
lazy[rt]+=k;
return;
}
pushdown(rt,l,r);
int mid=l+r>>1;
if(i<=mid)
update(lson,i,j,k);
if(mid<j)
update(rson,i,j,k);
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
}
int query(int l,int r,int rt,int i,int j)
{
if(i<=l&&r<=j)
{
return sum[rt];
}
pushdown(rt,l,r);
int mid=l+r>>1,ans=0;
if(i<=mid)
ans+=query(lson,i,j);
if(mid<j)
ans+=query(rson,i,j);
return ans;
}
int qsum(int u,int v)
{
int s=0;
int tu=top[u];
int tv=top[v];
while(tu!=tv){
if(dep[tu]<dep[tv]){
swap(tu,tv);
swap(u,v);
}

s+=query(1,n,1,id[tu],id[u]);
u=fa[tu];
tu=top[u];
}
if(dep[u]>dep[v])swap(u,v);
s+=query(1,n,1,id[u],id[v]);
return s;
}
signed main(){
cin>>n>>m;
for(int i=1;i<=n;i++)
cin>>b[i];
for(int i=1;i<=n-1;i++)
{
int u,v;
cin>>u>>v;
add(u,v);
add(v,u);
}
dfs1(1,1,1);
dfs2(1,1);
build(1,n,1);
for(int i=1;i<=m;i++)
{
int c,x,a;
cin>>c>>x;
if(c!=3)
cin>>a;
if(c==1)
update(1,n,1,id[x],id[x],a);
if(c==2)
update(1,n,1,id[x],id[x]+siz[x]-1,a);
if(c==3)
cout<<qsum(1,x)<<endl;

}
return 0;
}

题目

题目

上面两道题尝试自己完成

4.一些问题的解答

暂时还没有呢,欢迎提问!

欢迎打赏~.