Dang_Quan_10_Tin

FFT (Mod 998244353)

Jul 19th, 2022 (edited)
172
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.00 KB | None | 0 0
  1. constexpr int N = 1e5 + 5; // keep N double of n+m
  2.  
  3. constexpr ll mod = 997244353;
  4. ll Pow(ll a, ll b, const ll &mod = 998244353)
  5. {
  6.     ll ans(1);
  7.  
  8.     for (; b; b >>= 1)
  9.     {
  10.         if (b & 1)
  11.             ans = ans * a % mod;
  12.         a = a * a % mod;
  13.     }
  14.  
  15.     return ans;
  16. }
  17.  
  18. namespace ntt
  19. {
  20.     const int N = ::N;
  21.     const long long mod = ::mod, rt = 3;
  22.  
  23.     ll G[55], iG[55], itwo[55];
  24.  
  25.     void add(int &a, int b)
  26.     {
  27.         a += b;
  28.         if (a >= mod)
  29.             a -= mod;
  30.     }
  31.  
  32.     void init()
  33.     {
  34.         int now = (mod - 1) / 2, len = 1, irt = Pow(rt, mod - 2, mod);
  35.         while (now % 2 == 0)
  36.         {
  37.             G[len] = Pow(rt, now, mod);
  38.             iG[len] = Pow(irt, now, mod);
  39.             itwo[len] = Pow(1 << len, mod - 2, mod);
  40.             now >>= 1;
  41.             len++;
  42.         }
  43.     }
  44.  
  45.     void dft(ll *x, int n, int fg = 1) // fg=1 for dft, fg=-1 for inverse dft
  46.     {
  47.         for (int i = (n >> 1), j = 1, k; j < n; ++j)
  48.         {
  49.             if (i < j)
  50.                 swap(x[i], x[j]);
  51.             for (k = (n >> 1); k & i; i ^= k, k >>= 1)
  52.                 ;
  53.             i ^= k;
  54.         }
  55.         for (int m = 2, now = 1; m <= n; m <<= 1, now++)
  56.         {
  57.             ll r = fg > 0 ? G[now] : iG[now];
  58.             for (int i = 0, j; i < n; i += m)
  59.             {
  60.                 ll tr = 1, u, v;
  61.                 for (j = i; j < i + (m >> 1); ++j)
  62.                 {
  63.                     u = x[j];
  64.                     v = x[j + (m >> 1)] * tr % mod;
  65.                     x[j] = (u + v) % mod;
  66.                     x[j + (m >> 1)] = (u + mod - v) % mod;
  67.                     tr = tr * r % mod;
  68.                 }
  69.             }
  70.         }
  71.     }
  72.  
  73.     void brute(ll *a, ll *b, int n, int m)
  74.     {
  75.         static ll c[N];
  76.         for (int k = 0, t; k < n + m - 1; ++k)
  77.         {
  78.             t = 0;
  79.             for (int i = max(k - m + 1, 0); i < n && i <= k; ++i)
  80.             {
  81.                 add(t, a[i] * b[k - i] % mod);
  82.                 // if(k==2&&i==1)OUT(a[i]),OUT(b[k-i]);
  83.             }
  84.             c[k] = t;
  85.         }
  86.         for (int k = 0; k < n + m - 1; ++k)
  87.             a[k] = c[k];
  88.     }
  89.  
  90.     // Take two sequence a, b; return answer in sequence a
  91.  
  92.     void mul(ll *a, ll *b, int n, int m)
  93.     {
  94.         // a: 0,1,2,...,n-1; b: 0,1,2,...,m-1
  95.  
  96.         int nn = n + m - 1;
  97.  
  98.         if (n == 0 || m == 0)
  99.         {
  100.             memset(a, 0, nn * sizeof(a[0]));
  101.             return;
  102.         }
  103.  
  104.         int L, len;
  105.  
  106.         for (L = 1, len = 0; L < nn; ++len, L <<= 1)
  107.             ;
  108.         if (n < L)
  109.             memset(a + n, 0, (L - n) * sizeof(a[0]));
  110.         if (m < L)
  111.             memset(b + m, 0, (L - m) * sizeof(b[0]));
  112.  
  113.         dft(a, L, 1); // dft(a)
  114.         dft(b, L, 1); // dft(b)
  115.  
  116.         // Merge
  117.         for (int i = 0; i < L; ++i)
  118.             a[i] = a[i] * b[i] % mod;
  119.  
  120.         // Interpolation
  121.         dft(a, L, -1);
  122.  
  123.         for (int i = 0; i < L; ++i)
  124.             a[i] = a[i] * itwo[len] % mod;
  125.     }
  126. }
Add Comment
Please, Sign In to add comment