Rikka with Intersection of Paths 题解
Rikka有一棵包含n(1≤n≤3×10^5)个节点的树 T,节点编号为 1 到 n 。树上也标记了m (2≤m≤3×10^5) 条简单路径,第i条路径连接 x_i 和 y_i (1≤x_i,y_i≤n),这些路径可能会有重复的。如果要在其中选择 k 条路径,要求这 k 条路径至少有一个公共点,计算有多少种选择方案。输出其模 10^9+7 的结果。
该题目有一个性质,就是
一个树上任意两条路径如果有交点的话,那么这些交点中肯定有一个为两条路径中的一条路径两端点的 LCA
那么当两条路径相交的一种情况,我们只在这个为 LCA 的节点上进行统计,其他相交的点并不进行统计
这样就不会重复统计啦!
令覆盖某节点的路径有 M 条,覆盖本节点且本届点是这条路径两端点的LCA的路径有 N 条
那么我们就要从 **覆盖某节点的 M 条路径 选出 K 条路径,且该节点至少是一条路径的 LCA **
利用容斥原理:
从覆盖某节点的 M 条路径中选出 K 条,为 C_M^K
从覆盖某节点的 M 条路径中选出路径两端点的 LCA 必不是本节点的 K 条,为 C^K_{M-N}
那么 C_M^K-C^K_{M-N} 即为覆盖某节点的 M 条路径 选出 K 条路径,且该节点至少是一条路径的 LCA 为这个节点的方案数
然后结合求 LCA 和树上差分,就可以完成统计
(顺便吐槽一下求逆元,要注意 n 的逆元和 n! 逆元的区别
Code:
#include<bits/stdc++.h>
#define ll long long
const ll mod = 1e9+7;
const int N = 3e5+10;
int T,cnt,last[N],n,m,k;
int fa[N][26],d[N];
int sz[N],ann[N];
ll ans=0,inv[N],fac[N];
struct edge{
int next,to;
}e[N*2];
void add(int a,int b){
e[++cnt].next=last[a],last[a]=cnt;
e[cnt].to=b;
}
ll qpow(int a,int b){
if(b==0) return 1;
ll ans=qpow(a,b/2);
ans=1ll*ans*ans%mod;
if(b&1) ans=1ll*ans*a%mod;
return ans;
}
void dfs(int x,int f){
d[x]=d[f]+1;fa[x][0]=f;
for(int i=last[x];i;i=e[i].next){
if(e[i].to==f) continue;
dfs(e[i].to,x);
}
}
int lca(int a,int b){
if(d[a]<d[b]) std::swap(a,b);
for(int i=25;i+1;--i){
if(d[fa[a][i]]>=d[b])
a=fa[a][i];
}
if(a==b) return a;
for(int i=25;i+1;--i){
if(fa[a][i]!=fa[b][i])
a=fa[a][i],b=fa[b][i];
}
return fa[a][0];
}
void bz(){
for(int i=1;i<=25;++i){
for(int j=1;j<=n;++j){
fa[j][i]=fa[fa[j][i-1]][i-1];
}
}
}
void dfs2(int x,int f){
for(int i=last[x];i;i=e[i].next){
if(e[i].to==f) continue;
dfs2(e[i].to,x);
sz[x]+=sz[e[i].to];
}
}
ll C(int a,int b){
if(a<b) return 0;
if(a==b||b==0) return 1;
//printf("ans %d %d\n",a,b);
ll ans=(1ll*fac[a]*inv[b]%mod)*inv[a-b]%mod;
return ans;
}
int main(){
scanf("%d",&T);
inv[1]=1,fac[1]=1;
for(int i=2;i<=300000;++i){
fac[i]=(1ll*fac[i-1]*i)%mod;
}
inv[300000]=qpow(fac[300000],mod-2);
//printf("%d\n",qpow(2,3));
for(int i=300000-1;i;--i)
inv[i]=(1ll*inv[i+1]*(i+1))%mod;
while(T--){
memset(last,0,sizeof(last));
memset(fa,0,sizeof(fa));
memset(sz,0,sizeof(sz));
memset(ann,0,sizeof(ann));
memset(d,0,sizeof(d));
cnt=0;ans=0;;
scanf("%d %d %d",&n,&m,&k);
for(int i=1;i<=n-1;++i){
int a,b;
scanf("%d %d",&a,&b);
add(a,b),add(b,a);
}
dfs(1,0);
bz();
while(m--){
int a,b;
scanf("%d %d",&a,&b);
int an=lca(a,b);
//printf("an%d\n",an);
ann[an]++;
sz[a]++,sz[b]++,sz[an]--,sz[fa[an][0]]--;
}
dfs2(1,0);
// for(int i=1;i<=n;++i)
// printf("inv:%d\n",inv[i]);
for(int i=1;i<=n;++i){
ans=(1ll*ans+C(sz[i],k)-1ll*C(sz[i]-ann[i],k)+mod)%mod;
}
printf("%lld\n",ans);
}
return 0;
}
0 条评论