求大佬看看树链剖分模板哪里错了

回复帖子

@大雾山上 2019-05-15 22:52 回复
#include<bits/stdc++.h>
using namespace std;
const int M=1000001;
long long n,m,k,tot,cnt,MOD,a[M];
int head[M],size[M],fa[M],dep[M];
int son[M],seg[M],rev[M],top[M];
struct Edge
{
    int next,to;
}edge[1000001];
struct Tree
{
    int l,r;
    long long add,sum;
    #define l(x) tree[x].l
    #define r(x) tree[x].r
    #define add(x) tree[x].add
    #define sum(x) tree[x].sum
}tree[4000001];
void addd(int from,int to)
{
    edge[++tot].next=head[from];
    edge[tot].to=to;
    head[from]=tot;
}
void build(int p,int l,int r)
{
    l(p)=l;
    r(p)=r;
    if(l==r) 
    {
        sum(p)=a[rev[l]];
        return;
    }
    int mid=(l+r)/2;
    build(p*2,l,mid);
    build(p*2+1,mid+1,r);
    sum(p)=sum(p*2)+sum(p*2+1);
}
void spread(int p)
{
    if(add(p))
    {
        sum(p*2)+=(long long)add(p)*(r(p*2)-l(p*2)+1);
        sum(p*2+1)+=(long long)add(p)*(r(p*2+1)-l(p*2+1)+1);
        add(p*2)+=add(p);
        add(p*2+1)+=add(p);
        add(p)=0;
    }
}
void change(int l,int r,int p,int d)
{
    if(l<=l(p)&&r>=r(p))
    {
        sum(p)+=d*(r(p)-l(p)+1);
        add(p)+=d;
        return;
    }
    spread(p);
    int mid=(l(p)+r(p))/2;
    if(l<=mid) change(p*2,l,r,d);
    if(r>mid) change(p*2+1,l,r,d);
    sum(p)=sum(p*2)+sum(p*2+1);
}
long long ask(int l,int r,int p)
{
    if(l<=l(p)&&r>=r(p)) return sum(p);
    spread(p);
    int mid=(l(p)+r(p))/2;
    long long ans=0;
    if(l<=mid) ans+=ask(p*2,l,r);
    if(r>mid) ans+=ask(p*2+1,l,r);
    return ans%MOD;
}
void dfs1(int u,int f)
{
    dep[u]=dep[f]+1;
    fa[u]=f;
    size[u]=1;
    for(int i=head[u];i;i=edge[i].next)
    {
        int v=edge[i].to;
        if(v!=f)
        {
            dfs1(v,u);
            size[u]+=size[v];
            if(size[v]>size[son[u]])
            son[u]=v;
        }
    }
}
void dfs2(int u,int f)
{
    if(son[u])
    {
        seg[son[u]]=++cnt;
        rev[cnt]=son[u];
        top[son[u]]=top[u];
        dfs2(son[u],u);
    }
    for(int i=head[u];i;i=edge[i].next)
    {
        int v=edge[i].to;
        if(!top[v])
        {
            seg[v]=++cnt;
            rev[cnt]=v;
            top[v]=v;
            dfs2(v,u);
        }
    }
}
void get_init(int x,int y,int d)
{
    int fx=top[x],fy=top[y];
    while(fx!=fy)
    {
        if(dep[fx]<dep[fy]) 
        {
            swap(fx,fy);
            swap(x,y);
        }
        change(cnt,seg[fx],seg[fy],d);
        x=fa[x];
        fx=top[x];
    }
    if(dep[x]>dep[y])
    swap(x,y);
    change(cnt,seg[x],seg[y],d);
    return;
}
int get_sum(int x,int y)
{
    int ans=0;
    int fx=top[x],fy=top[y];
    while(fx!=fy)
    {
        if(dep[fx]<dep[fy]) 
        {
            swap(fx,fy);
            swap(x,y);
        }
        ans+=ask(cnt,seg[fx],seg[x]);
        x=fa[x];
        fx=top[x];
    }
    if(dep[x]>dep[y])
    swap(x,y);
    ans+=ask(cnt,seg[x],seg[y]);
    return ans%MOD;
}
int main()
{
    cin>>n>>m>>k>>MOD;
    for(int i=1;i<=n;i++) 
    cin>>a[i];
    for(int i=1;i<=n-1;i++)
    {
        int x,y;
        cin>>x>>y;
        addd(x,y);
        addd(y,x);
    }
    dfs1(k,k);
    cnt=seg[1]=top[1]=rev[1]=1;
    dfs2(k,k);
    build(1,1,cnt);
    for(int i=1;i<=m;i++)
    {
        int op;
        cin>>op;
        if(op==1)
        {
            int x,y,z;
            cin>>x>>y>>z;
            get_init(x,y,z);
        }
        if(op==2)
        {
            int x,y;
            cin>>x>>y;
            cout<<get_sum(x,y)<<endl;
        }
        if(op==3)
        {
            int x,y;
            cin>>x>>y;
            change(1,seg[x],seg[x]+size[x]-1,y);
        }
        if(op==4)
        {
            int x;
            cin>>x;
            cout<<ask(1,seg[x],seg[x]+size[x]-1)<<endl;
        }
    }
    return 0;
}
@CSJ1 2019-05-16 07:44 回复

DFS2的第二个参数不应该是f吗

@CSJ1 2019-05-16 13:00 回复

@大雾山上 你的DFS2不是有两个参数吗,其中第二个记的是链的顶端,所以你在向下遍历重儿子的时候还是f