#include <bits/stdc++.h>
using namespace std;
// Fast integer reader on stdin.
// Skips every non-digit character, remembering whether a '-' was seen before
// the first digit, then parses the following run of decimal digits into x.
// If the input is exhausted before any digit appears, x is left as 0 and the
// function returns instead of looping forever waiting for a digit.
template<typename T>void read(T&x){
    x = 0;
    int fl = 1;
    int ch = getchar();  // int, not char: must be able to represent EOF
    while(ch < '0' || ch > '9'){
        if(ch == EOF) return;  // no further number in the input
        if(ch == '-') fl = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    x *= fl;
}
// Reads several values in order: read(a, b, c) is read(a); read(b); read(c);
template<typename T,typename...Args>inline void read(T&t,Args&...args){read(t);read(args...);}
// Shorthand aliases used throughout the solution.
using LL = long long;
using vi = std::vector<int>;
using pii = std::pair<int, int>;
// Each macro must start its own line: a '#define' preceded by other tokens on
// the same line is not a preprocessor directive and fails to compile.
#define mp(x, y) std::make_pair(x, y)
#define pb(x) emplace_back(x)
// Capacity of every per-node and per-value array.
const int N = 100005;

int n;                     // number of tree nodes (read in main)
int k;                     // bound on the tree distance of counted pairs (used by getAns)
int v[N];                  // v[i] = value of node i
int root[N];               // root[val] = segment-tree root for nodes with value val
LL ans;                    // pair count accumulated by getAns; main prints ans * 2
std::vector<int> edge[N];  // edge[p] = children of node p
// Per-node results of dfs: subtree size, heavy child (0 if none) and depth.
int sz[N], son[N], dep[N];

// Fills sz/son/dep for the subtree rooted at x, where x sits at depth d.
// The heavy child is the first child whose subtree is strictly the largest.
void dfs(int x, int d){
    dep[x] = d;
    sz[x] = 1;
    son[x] = 0;
    for(size_t i = 0; i < edge[x].size(); ++i){
        int child = edge[x][i];
        dfs(child, d + 1);
        sz[x] += sz[child];
        bool heavier = (son[x] == 0) || (sz[child] > sz[son[x]]);
        if(heavier) son[x] = child;
    }
}
int tot; struct segTree{ int lson, rson, cnt; }tr[N*200]; inline void pushup(int rt){ tr[rt].cnt = tr[tr[rt].lson].cnt + tr[tr[rt].rson].cnt; } void add(int rt, int l, int r, int p, int val){ if(l == r){ tr[rt].cnt += val; return; } int mid = (l + r) >> 1; if(p <= mid){ if(!tr[rt].lson) tr[rt].lson = ++tot; add(tr[rt].lson, l, mid, p, val); } else{ if(!tr[rt].rson) tr[rt].rson = ++tot; add(tr[rt].rson, mid+1, r, p, val); } pushup(rt); } LL query(int rt, int l, int r, int L, int R){ if(rt == 0) return 0; if(L > R) return 0; if(l == L && r == R) return tr[rt].cnt; int mid = (l + r) >> 1; if(R <= mid) return query(tr[rt].lson, l, mid, L, R); else if(L > mid) return query(tr[rt].rson, mid+1, r, L, R); else return query(tr[rt].lson, l, mid, L, mid) + query(tr[rt].rson, mid+1, r, mid+1, R); }
void getAns(int x, int rtx){ if(2*v[rtx]-v[x] >= 0 && 2*v[rtx]-v[x] <= n) ans += query(root[2*v[rtx]-v[x]], 1, n, 1, min(n, k + 2 * dep[rtx] - dep[x])); for(auto &to : edge[x]) getAns(to, rtx); } void getData(int x){ add(root[v[x]], 1, n, dep[x], 1); for(auto &to : edge[x]) getData(to); } void delData(int x){ add(root[v[x]], 1, n, dep[x], -1); for(auto &to : edge[x]) delData(to); } void dsu(int x, bool opt){ for(auto &to : edge[x]){ if(to == son[x]) continue; dsu(to, true); } if(son[x]) dsu(son[x], false); for(auto &to : edge[x]){ if(to == son[x]) continue; getAns(to, x); getData(to); } if(opt){ for(auto &to : edge[x]) delData(to); } else add(root[v[x]], 1, n, dep[x], 1); }
int main(){ read(n, k); for(int i = 1; i <= n; i++) read(v[i]); for(int i = 2; i <= n; i++){ int f; read(f); edge[f].emplace_back(i); } dfs(1, 1); for(int i = 0; i <= n; i++) root[i] = ++tot; dsu(1, true); printf("%lld\n", ans * 2); return 0; }