Advertisement
Saleh127

RUET IUPC Problem J / NNT + Divide and Conquer

Aug 15th, 2022
874
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.57 KB | None | 0 0
  1. /***
  2.  created: 2022-08-14-23.16.33
  3. ***/
  4.  
  5. #include <bits/stdc++.h>
  6. #include <ext/pb_ds/assoc_container.hpp>
  7. #include <ext/pb_ds/tree_policy.hpp>
  8. using namespace std;
  9. using namespace __gnu_pbds;
  10. template<typename U> using ordered_set=tree<U, null_type,less<U>,rb_tree_tag,tree_order_statistics_node_update>;
  11. #define ll long long
  12. #define int long long
  13. #define test int tt; cin>>tt; for(int cs=1;cs<=tt;cs++)
  14. #define get_lost_idiot return 0
  15. #define nl '\n'
  16. #define PI acos(-1.0)
  17.  
  18. const int G = 3;
  19. const int MOD =  998244353;
  20. const int N = (1 << 18) + 5;
  21.  
  22. int rev[N], w[N], inv_n;
  23.  
  24. int bigMod (int a, int e, int mod)
  25. {
  26.     if (e == -1) e = mod - 2;
  27.     int ret = 1;
  28.     while (e)
  29.     {
  30.         if (e & 1) ret = (ll) ret * a % mod;
  31.         a = (ll) a * a % mod;
  32.         e >>= 1;
  33.     }
  34.     return ret;
  35. }
  36.  
  37. void prepare (int &n)
  38. {
  39.     int sz = abs(31 - __builtin_clz(n));
  40.     int r = bigMod(G, (MOD - 1) / n, MOD);
  41.     inv_n = bigMod(n, MOD - 2, MOD), w[0] = w[n] = 1;
  42.     for (int i = 1; i < n; ++i) w[i] = (ll) w[i - 1] * r % MOD;
  43.     for (int i = 1; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (sz - 1));
  44. }
  45.  
  46. void ntt (int *a, int n, int dir)
  47. {
  48.     for (int i = 1; i < n - 1; ++i)
  49.     {
  50.         if (i < rev[i]) swap(a[i], a[rev[i]]);
  51.     }
  52.     for (int m = 2; m <= n; m <<= 1)
  53.     {
  54.         for (int i = 0; i < n; i += m)
  55.         {
  56.             for (int j = 0; j < (m >> 1); ++j)
  57.             {
  58.                 int &u = a[i + j], &v = a[i + j + (m >> 1)];
  59.                 int t = (ll) v * w[dir ? n - n / m * j : n / m * j] % MOD;
  60.                 v = u - t < 0 ? u - t + MOD : u - t;
  61.                 u = u + t >= MOD ? u + t - MOD : u + t;
  62.             }
  63.         }
  64.     }
  65.     if (dir) for (int i = 0; i < n; ++i) a[i] = (ll) a[i] * inv_n % MOD;
  66. }
  67.  
  68. int f_a[N], f_b[N];
  69.  
  70. vector <int> multiply (vector <int> a, vector <int> b)
  71. {
  72.     int sz = 1, n = a.size(), m = b.size();
  73.     while (sz < n + m - 1) sz <<= 1;
  74.     prepare(sz);
  75.     for (int i = 0; i < sz; ++i) f_a[i] = i < n ? a[i] : 0;
  76.     for (int i = 0; i < sz; ++i) f_b[i] = i < m ? b[i] : 0;
  77.     ntt(f_a, sz, 0);
  78.     ntt(f_b, sz, 0);
  79.     for (int i = 0; i < sz; ++i) f_a[i] = (ll) f_a[i] * f_b[i] % MOD;
  80.     ntt(f_a, sz, 1);
  81.     return vector <int> (f_a, f_a + n + m - 1);
  82. }
  83.  
  84. vector<int>v;
  85. ll modv[100005];
  86. vector<long long> divideAndConqure(int l, int r)
  87. {
  88.     if (l == r)
  89.     {
  90.         vector<long long> k = {1,modv[ v[l]]};
  91.         return k;
  92.     }
  93.     int mid = (l + r) >> 1;
  94.     vector<long long> x = divideAndConqure(l, mid);
  95.     vector<long long> y = divideAndConqure(mid + 1, r);
  96.     return multiply(x, y);
  97. }
  98.  
  99.  
  100. main()
  101. {
  102.     for(ll x1=1; x1<=100001; x1++)
  103.     {
  104.         modv[x1]=(bigMod(2ll, x1, MOD) - 1ll + MOD) %MOD;
  105.     }
  106.  
  107.     int ttt;
  108.  
  109.     scanf("%lld",&ttt);
  110.  
  111.     for(int ca=1; ca<=ttt; ca++)
  112.     {
  113.         int i,n,m,j,k,l;
  114.  
  115.         scanf("%lld %lld",&n,&k);
  116.  
  117.         int cnt[n+4]= {0ll};
  118.  
  119.         for(i=0; i<n; i++)
  120.         {
  121.             scanf("%lld",&m);
  122.             cnt[m]++;
  123.         }
  124.  
  125.         v.clear();
  126.  
  127.         for(i=1; i<=n; i++)
  128.         {
  129.             if(cnt[i])
  130.             {
  131.                 v.push_back(cnt[i]);
  132.             }
  133.         }
  134.  
  135.         ll sz1=v.size()-1;
  136.  
  137.         vector<long long>ansvector = divideAndConqure(0,sz1);
  138.  
  139.         ll ans=0;
  140.  
  141.         ll sz=ansvector.size();
  142.  
  143.         if(sz>k)
  144.         {
  145.             for(i=k; i<sz; i++)
  146.             {
  147.                 ans+=ansvector[i];
  148.                 if(ans>=MOD) ans-=MOD;
  149.             }
  150.         }
  151.  
  152.         printf("Case %lld: %lld\n",ca,ans);
  153.     }
  154.  
  155.     get_lost_idiot;
  156. }
  157.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement