Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- /* how many pairs in a tree that constitute a number that is divisible by M */
- /* Cf - 716E */
- int Subtree[MAX],del[MAX],n,totalnode;
- ll inv[MAX],power[MAX];ll M;vector<pll>v[MAX];ll up[MAX],down[MAX],idx,deep[MAX];vi lst;ll ans;ll comp[MAX];
- void precal(){
- for(ll b = 1;b<=10;b++){
- if(b*M %10 !=1)continue;
- inv[1] =( 1 - b*M)/10 + M;
- inv[1] %= M;
- }
- for(int i = 2;i<MAX;i++)inv[i] = (inv[i-1]*inv[1]+M)%M;
- power[0] = 1;
- for(int i = 1;i<MAX;i++)power[i] = (power[i-1]*10)%M;
- }
- void dfs2(int s,int p)
- {
- Subtree[s] = 1;
- totalnode++;
- for(auto it : v[s])
- {
- if(it.ff==p||del[it.ff])continue;
- dfs2(it.ff,s);
- Subtree[s]+=Subtree[it.ff];
- }
- }
- int Getcenter(int s,int p)
- {
- for(auto it : v[s])
- {
- if(it.ff==p||del[it.ff])continue;
- if(Subtree[it.ff]>totalnode/2)return Getcenter(it.ff,s);
- }
- return s;
- }
- void dfs(int s,int p = -1,ll c = 0,int depth = 0){
- lst.pb(s);
- comp[s] = idx;
- deep[s] = depth;
- if(p!=-1){
- down[s] = (down[p] * 10 + c) % M;
- up[s] = c %M * power[depth-1] + up[p];
- up[s]%=M;
- }
- for(auto it : v[s]){
- if(it.ff==p||del[it.ff])continue;
- if(p==-1)idx++;
- dfs(it.ff,s,it.ss,depth+1);
- }
- }
- void Clear(int s){
- up[s] = 0;
- down[s] = 0;
- idx = 0;
- lst.clear();
- deep[s] = 0;
- dfs(s);
- }
- ll solve(int s,int p = -1 ,ll c = 0,int depth = 0){
- map<ll,ll>cnt;
- map<ll,ll>cmp[idx+50];
- for(auto it : lst){
- cnt[up[it]]++;
- cmp[comp[it]][up[it]]++;
- }
- ll ret = cnt[0] - 1;
- for(auto it : lst){
- if(it==s)continue;
- ll val = -down[it] %M * inv[deep[it]] % M + M;
- val %= M;
- ret+=cnt[val];
- ret-=cmp[comp[it]][val];
- }
- return ret;
- }
- void Decompose(int x,int p)
- {
- totalnode = 0;
- dfs2(x,p);
- int s = Getcenter(x,p);
- del[s] = 1;
- Clear(s);
- ans+=solve(s);
- ///cout<<ans<<endl;
- for(auto it : v[s])
- {
- if(it.ff==p||del[it.ff])continue;
- Decompose(it.ff,s);
- }
- }
- int main()
- {
- booster()
- ///read("input.txt");
- cin>>n>>M;
- precal();
- for(int i = 0;i<n-1;i++)
- {
- ll a,b,w;
- cin>>a>>b>>w;
- v[a].pb(pll(b,w));
- v[b].pb(pll(a,w));
- }
- Decompose(0,-1);
- cout<<ans;
- return 0;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement