浅谈树链剖分

引子

在OI中,有时候我们会需要处理一些树上的链的问题

比方说,给定一棵$n$个点的树,$m$个操作,每次查询$x$和$y$之间的链上的和

不要在意到底是什么操作,这真的只是个引子

考虑做法。

  1. level 1

    $ n,m \leq 100 $

    那么我们直接暴力求即可。复杂度$O(nm)$

  2. level 2

    $ n\leq 1000 , m \leq 100000$

    很明显不能直接暴力了。

    询问过多,但是并没有修改操作,所以可以考虑$O(n^2)$把$x$和$y$之间的和预处理出来,存起来,直接访问。

  3. level 3

    $ n\leq10^5,m\leq10^5$,并且要求支持修改操作

    这才是我们今天要讨论的问题

    为了支持$O(log^2n)$的修改,$O(log^2n)$的查询,我们发展了这种叫树链剖分的东西。

树链剖分

前置芝士

  • DFS

  • 线段树

  • 链式前向星

定义

  • 重儿子:一个点有多个儿子,定义其儿子中子树大小最大的儿子为重儿子。

  • 轻儿子:一个点不是重儿子的儿子都是轻儿子。

  • 重边:一个点与其重儿子之间的边。

  • 轻边:一个点与其轻儿子之间的边。

  • 重链:完全由重边连成的链。

  • 重链的顶端:一条重链上深度最小(最靠近根)的点。

  • 特别的,我们为了代码的舒适性人为定义一个轻儿子为一条长度为1的重链。

树链剖分能做什么?

  • 解决一条链上的信息查询/修改问题

  • 其他链上线段树能维护的东西

实现

存树

这个实际上不难,所有数字直接先丢进线段树的数组,边的话直接读入的时候链式前向星存下来就好。

考虑到是双向边,要加两次。

1
2
3
4
5
6
7
8
9
10
11
12
13
inline void add(int u,int v){//链式前向星加边
to[++bian]=v;//记下现在第bian条边所指向的节点
nxt[bian]=beg[u];//指针指向u的链表原来的表头
beg[u]=bian;//更新表头
}
scanf ("%lld%lld%lld%lld",&n,&m,&r,&mod);//读入
fa[r]=0,dep[r]=1;//r是根节点,根节点的父亲是0,深度的1
for (int i=1;i<=n;i++)scanf ("%lld",&tree.a[i]);
for (int i=1;i<n;i++){
ll a,b;
scanf("%lld%lld",&a,&b);
add(a,b),add(b,a);
}

两遍DFS

第一遍DFS

这遍DFS要处理出以下的信息:

  • 一个点的深度$d$

  • 一个点的重儿子$s$

  • 一个点的父亲$f$

  • 一个点的子树大小(含自己)$siz$

代码:

1
2
3
4
5
6
7
8
9
10
11
inline void dfs1(int x){
siz[x]=1;//初始大小是1(只有自己)
for (int i=beg[x];i;i=nxt[i]){//访问x的所有出边
if (to[i]==fa[x])continue;//到父亲的边,不考虑
fa[to[i]]=x;//访问到的节点的父亲是x
dep[to[i]]=dep[x]+1;//深度是x+1
dfs1(to[i]);//向访问到的节点DFS
siz[x]+=siz[to[i]];//加上儿子的子树大小
if (siz[to[i]]>siz[son[x]])son[x]=to[i];//更新重儿子的信息
}
}

第二遍DFS

在第一遍DFS的基础上,我们现在知道了每个点的子树大小(包括自己),重儿子等信息。

接下来就是树链剖分的核心了:

把一棵树按照轻重边剖分成若干条链,剖分的过程就是第二遍DFS

至于剖分的原因,后面在证明复杂度的时候会说

具体实现:

我们对每个点重标号,使得一条重链上的点的标号是连续的,然后对重标号后的点建线段树

第二遍DFS需要处理这些内容:

  • 记录下每个点的新标号

  • 把这个点的值赋到新标号上(之后建线段树要用)

  • 记录下每个点所在的重链的顶端

代码:

1
2
3
4
5
6
7
8
inline void dfs2(int x,int y){
top[x]=y;id[x]=++tot;num[tot]=x;//记录下x所在的重链的顶端y,同时为新标号赋值
if (son[x])dfs2(son[x],y);//优先DFSx的重儿子,这样保证一条重链上的点的标号是连续的
for (int i=beg[x];i;i=nxt[i]){
if (to[i]==fa[x]||to[i]==son[x])continue;
dfs2(to[i],to[i]);//DFS剩下的轻儿子
}
}

处理

敲黑板:重点来了!也许你不清楚前文提到的两边DFS的意义,这里会有解释

我们进行了第二遍DFS之后,得到了一下的结果:

  • 由于我们DFS时优先考虑重儿子,这样每条重链上的点的新标号是连续的

  • 由于是DFS,每棵子树的新标号是连续的

链的查询/修改

有了上文,本来链上的问题变成了统计若干条重链加轻边的问题。

查询和修改的思路实际上很相似,都是两个点分别向上跳,直到在同一条重链上为止,最后直接统计一次重链上两个点之间的和

既然重链上的新标号的连续的,那么我们就可以用线段树维护每条重链的和,这样重链的查询就是$O(logn)$的了。

至于轻边,我们可以直接暴力累加,但由于我们人为定义了每个轻儿子是一条长度为1的重链,所以实际上写代码的时候可以和重链的查询合并。

修改的代码:

1
2
3
4
5
6
7
8
9
10
void chge(int x,int y,int k){
k%=mod;
while (top[x]!=top[y]){//使用类似倍增求LCA的思想,每次找深度大的点往上跳
if (dep[top[x]]<dep[top[y]])swap(x,y);//保证x的深度大
tree.change(1,id[top[x]],id[x],1,n,k);//直接统计x所在的重链的和,这条重链的起点是x所在的重链的顶端的新标号,终点是x的新标号
x=fa[top[x]];//直接跳到x所在重链顶端的父亲,保证不重复统计
}//把两个点跳到同一条重链上
if (dep[x]<dep[y])swap(x,y);
tree.change(1,id[y],id[x],1,n,k);//最后处理重链上两个点之间的部分
}

查询的代码:

1
2
3
4
5
6
7
8
9
10
11
int query(int x,int y){
int ans=0;
while (top[x]!=top[y]){
if (dep[top[x]]<dep[top[y]])swap(x,y);//同样是选深度大的向上跳
ans=(ans+tree.ask(1,id[top[x]],id[x],1,n))%mod;
x=fa[top[x]];
}
if (dep[x]<dep[y])swap(x,y);
ans=(ans+tree.ask(1,id[y],id[x],1,n))%mod;//也是最后处理同一条重链上x和y之间的部分
return ans;
}

实际上很像,只是把线段树的修改换成了查询而已。

子树的修改/查询

既然有了一棵子树的新标号的连续的保证,那么原来的一棵子树实际上对应着线段树中的一段连续区间。

那么直接线段树区间修改/查询就好。而且由于是连续区间,那么$x$的子树的起点必然是$x$的新标号,终点必然是$x$的新标号$+x$的子树大小再$-1$。

那么代码就很简单了:

修改:

1
tree.change(1,id[x],id[x]+siz[x]-1,1,n,k)

查询:

1
tree.ask(1,id[x],id[x]+siz[x]-1,1,n)

复杂度简单证明

当树为二叉树的时候,深度最大。

一棵树总共有$n$个点,由于一个点$x$的重儿子起码占了$x$的子树大小的一半,那么每次递归下去节点个数都减半,很明显$logn$次就能结束。

这里特殊的是链的情况,然而整棵树上有一大部分是链的时候,显然他们必然连成若干条重链,而不可能都是轻边,那么对我们的复杂度并不构成巨大的影响。

所以重链的数量$\leq logn$。

又因为每两条重链之间必然由轻边分割(不然就连成一条重链了),所以轻边的数量同样$\leq logn$。

于是树链剖分做到了对于一条链上的查询/修改,$O(log^2n)$的复杂度。

很是优秀了。

完整代码

实际上讲到这里应该都会写了吧(雾,但为了讲清楚一些细节还是给一下吧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include <bits/stdc++.h>
using namespace std;
const int N=100005;
#define ll long long
int dep[N],siz[N],fa[N],z[N],to[N<<1],beg[N<<1],nxt[N<<1],top[N<<2],bian,son[N<<1],id[N<<1],tot,n,m,r,mod,num[N];
struct Tree{
ll ans[N<<2],tag[N<<2],a[N];
inline ll lson(ll p){return p<<1;}
inline ll rson(ll p){return (p<<1)|1;}
inline void push_up(ll p){ans[p]=(ans[lson(p)]+ans[rson(p)])%mod;}
inline void build(ll p,ll l,ll r){
if (l==r){ans[p]=a[num[l]];return ;}
ll mid=(l+r)>>1;
build(lson(p),l,mid);
build(rson(p),mid+1,r);
push_up(p);
tag[p]=0;
}
inline void lazy_tag(ll p,ll l,ll r,ll k){ans[p]=(ans[p]+(r-l+1)*k)%mod,tag[p]=(tag[p]+k)%mod;}
inline void push_down(ll p,ll l,ll r){
ll mid=(l+r)>>1;
lazy_tag(lson(p),l,mid,tag[p]);
lazy_tag(rson(p),mid+1,r,tag[p]);
tag[p]=0;
}
inline void change(ll p,ll nl,ll nr,ll l,ll r,ll k){//nl,nr->changing l,changing r;l,r->visiting l,visiting r
if (nl<=l&&nr>=r){ans[p]=(ans[p]+(r-l+1)*k)%mod,tag[p]=(tag[p]+k)%mod;return ;}
ll mid=(l+r)>>1;
push_down(p,l,r);
if (nl<=mid)change(lson(p),nl,nr,l,mid,k);
if (nr>mid)change(rson(p),nl,nr,mid+1,r,k);
push_up(p);
}
inline ll ask(ll p,ll nl,ll nr,ll l,ll r){
if (nl<=l&&nr>=r)return ans[p];
ll mid=(l+r)>>1,res=0;
push_down(p,l,r);
if (nl<=mid)res=(res+ask(lson(p),nl,nr,l,mid))%mod;
if (nr>mid)res=(res+ask(rson(p),nl,nr,mid+1,r))%mod;
return res;
}
}tree;//之前封装好的线段树
inline void add(int u,int v){//链式前向星加边
to[++bian]=v;
nxt[bian]=beg[u];
beg[u]=bian;
}
inline void dfs1(int x){//第一遍dfs,之前已经详细讲过
siz[x]=1;
for (int i=beg[x];i;i=nxt[i]){
if (to[i]==fa[x])continue;
fa[to[i]]=x;
dep[to[i]]=dep[x]+1;
dfs1(to[i]);
siz[x]+=siz[to[i]];
if (siz[to[i]]>siz[son[x]])son[x]=to[i];
}
}
inline void dfs2(int x,int y){//同上
top[x]=y;id[x]=++tot;num[tot]=x;
if (son[x])dfs2(son[x],y);
for (int i=beg[x];i;i=nxt[i]){
if (to[i]==fa[x]||to[i]==son[x])continue;
dfs2(to[i],to[i]);
}
}
int query(int x,int y){//链的查询,也讲过了
int ans=0;
while (top[x]!=top[y]){
if (dep[top[x]]<dep[top[y]])swap(x,y);
ans=(ans+tree.ask(1,id[top[x]],id[x],1,n))%mod;
x=fa[top[x]];
}
if (dep[x]<dep[y])swap(x,y);
ans=(ans+tree.ask(1,id[y],id[x],1,n))%mod;
return ans;
}
void chge(int x,int y,int k){//链的修改
k%=mod;
while (top[x]!=top[y]){
if (dep[top[x]]<dep[top[y]])swap(x,y);
tree.change(1,id[top[x]],id[x],1,n,k);
x=fa[top[x]];
}
if (dep[x]<dep[y])swap(x,y);
tree.change(1,id[y],id[x],1,n,k);
}
int main(){
scanf ("%lld%lld%lld%lld",&n,&m,&r,&mod);fa[r]=0,dep[r]=1;
for (int i=1;i<=n;i++)scanf ("%lld",&tree.a[i]);
for (int i=1;i<n;i++){
ll a,b;
scanf("%lld%lld",&a,&b);
add(a,b),add(b,a);
}
dfs1(r);//从根节点开始
dfs2(r,r);//同上,根节点必然是根所在的重链的顶端
tree.build(1,1,n);//对于我们重标号之后的数组,建线段树
while(m--){
ll flag,x,y,k;
scanf ("%lld",&flag);
switch (flag){
case 1:scanf ("%lld%lld%lld",&x,&y,&k);chge(x,y,k);break;//链上修改
case 2:scanf ("%lld%lld",&x,&y);printf("%lld\n",query(x,y));break;//链上查询
case 3:scanf ("%lld%lld",&x,&k);tree.change(1,id[x],id[x]+siz[x]-1,1,n,k);break;//子树修改
case 4:scanf ("%lld",&x);printf("%lld\n",tree.ask(1,id[x],id[x]+siz[x]-1,1,n));//子树查询
}
}
}