Matrix_code

math - NTT

Feb 12th, 2017 (edited)
184
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 4.40 KB | None | 0 0
  1. // NTT
  2. // mod defines the modulo of the problem. In example 1e9+7
  3.  
  4.  
  5. #include <stdio.h>
  6. #include <bits/stdc++.h>
  7. using namespace std;
  8.  
  9.  
  10. #include <vector>
  11. #include <utility>
  12.  
  13.  
  14. template <typename T>
  15. T extGcd(T a, T b, T& x, T& y) {
  16.    if (b == 0) {
  17.       x = 1;
  18.       y = 0;
  19.       return a;
  20.    }
  21.    else {
  22.       int g = extGcd(b, a % b, y, x);
  23.       y -= a / b * x;
  24.       return g;
  25.    }
  26. }
  27.  
  28. template <typename T>
  29. T modInv(T a, T m) {
  30.    T x, y;
  31.    extGcd(a, m, x, y);
  32.    return (x % m + m) % m;
  33. }
  34.  
  35. long long crt(const std::vector< std::pair<int, int> >& pp, int mod = -1);
  36.  
  37.  
  38. #include <algorithm>
  39.  
  40.  
  41. struct FFT_mod {
  42.    int mod, root, root_1, root_pw;
  43. };
  44.  
  45. extern FFT_mod suggested_fft_mods[5];
  46. void ntt_shortmod(std::vector<int>& a, bool invert, const FFT_mod& mod_data);
  47.  
  48.  
  49. const int mod = 1000000007;
  50.  
  51. vector<int> mull(const vector<int>& left, const vector<int>& right, const FFT_mod& mod_data) {
  52.    vector<int> left1 = left, right1 = right;
  53.    ntt_shortmod(left1, false, mod_data);
  54.    ntt_shortmod(right1, false, mod_data);
  55.    
  56.    for (int i = 0; i < left.size(); i++) {
  57.       left1[i] = (left1[i] * 1ll * right1[i]) % mod_data.mod;
  58.    }
  59.    
  60.    ntt_shortmod(left1, true, mod_data);
  61.    return left1;
  62. }
  63.  
  64.  
  65. vector<int> mult(vector<int>& left, vector<int>& right) {
  66.    int ssss = left.size() + right.size() - 1;
  67.    int pot2;
  68.    for (pot2 = 1; pot2 < ssss; pot2 <<= 1);
  69.    
  70.    left.resize(pot2);
  71.    right.resize(pot2);
  72.    
  73.    vector<int> res[3];
  74.    for (int i = 0; i < 3; i++) {
  75.       res[i] = mull(left, right, suggested_fft_mods[i]);
  76.    }
  77.    
  78.    vector<int> ret(pot2);
  79.    for (int i = 0; i < pot2; i++) {
  80.       vector< pair<int,int> > mod_results;
  81.       for (int j = 0; j < 3; j++) {
  82.          mod_results.emplace_back(res[j][i], suggested_fft_mods[j].mod);
  83.       }
  84.       ret[i] = crt(mod_results, mod);
  85.    }
  86.    return ret;
  87. }
  88.  
  89. long long crt(const std::vector< std::pair<int, int> >& a, int mod) {
  90.    long long res = 0;
  91.    long long mult = 1;
  92.    
  93.    int SZ = a.size();
  94.    std::vector<int> x(SZ);
  95.    for (int i = 0; i<SZ; ++i) {
  96.       x[i] = a[i].first;
  97.       for (int j = 0; j<i; ++j) {
  98.          long long cur = (x[i] - x[j]) * 1ll * modInv(a[j].second,a[i].second);
  99.          x[i] = (int)(cur % a[i].second);
  100.          if (x[i] < 0) x[i] += a[i].second;
  101.       }
  102.       res = (res + mult * 1ll * x[i]);
  103.       mult = (mult * 1ll * a[i].second);
  104.       if (mod != -1) {
  105.          res %= mod;
  106.          mult %= mod;
  107.       }
  108.    }
  109.    
  110.    return res;
  111. }
  112.  
  113.  
  114. FFT_mod suggested_fft_mods[] = {
  115.    { 7340033, 5, 4404020, 1 << 20 },
  116.    { 415236097, 73362476, 247718523, 1<<22 },
  117.    { 463470593, 428228038, 182429, 1<<21},
  118.    { 998244353, 15311432, 469870224, 1 << 23 },
  119.    { 918552577, 86995699, 324602258, 1 << 22 }
  120. };
  121.  
  122. int FFT_w[1050000];
  123. int FFT_w_dash[1050000];
  124.  
  125.  
  126. void ntt_shortmod(std::vector<int>& a, bool invert, const FFT_mod& mod_data) {
  127.    // only use if mod < 5*10^8
  128.    int n = (int)a.size();
  129.    int mod = mod_data.mod;
  130.    
  131.    for (int i = 1, j = 0; i<n; ++i) {
  132.       int bit = n >> 1;
  133.       for (; j >= bit; bit >>= 1)
  134.          j -= bit;
  135.       j += bit;
  136.       if (i < j)
  137.          std::swap(a[i], a[j]);
  138.    }
  139.    
  140.    for (int len = 2; len <= n; len <<= 1) {
  141.       int wlen = invert ? mod_data.root_1 : mod_data.root;
  142.       for (int i = len; i<mod_data.root_pw; i <<= 1)
  143.          wlen = int(wlen * 1ll * wlen % mod_data.mod);
  144.      
  145.       long long tt = wlen;
  146.       for (int i = 1; i < len / 2; i++) {
  147.          FFT_w[i] = tt;
  148.          FFT_w_dash[i] = (tt << 31) / mod;
  149.          int q = (FFT_w_dash[1] * 1ll * tt) >> 31;
  150.          tt = (wlen * 1ll * tt - q * 1ll * mod) & ((1LL << 31) - 1);
  151.          if (tt >= mod) tt -= mod;
  152.       }
  153.       for (int i = 0; i<n; i += len) {
  154.          int uu = a[i], vv = a[i + len / 2] % mod;
  155.          if (uu >= 2*mod) uu -= 2*mod;
  156.          a[i] = uu + vv;
  157.          a[i + len / 2] = uu - vv + 2 * mod;
  158.          
  159.          for (int j = 1; j<len / 2; ++j) {
  160.             int u = a[i + j];
  161.             if (u >= 2*mod) u -= 2*mod;
  162.             int q = (FFT_w_dash[j] * 1ll * a[i + j + len / 2]) >> 31;
  163.             int v = (FFT_w[j] * 1ll * a[i + j + len / 2] - q * 1ll * mod) & ((1LL << 31) - 1);
  164.             a[i + j] = u + v;
  165.             a[i + j + len / 2] = u - v + 2*mod;
  166.          }
  167.       }
  168.    }
  169.    if (invert) {
  170.       int nrev = modInv(n, mod);
  171.       for (int i = 0; i<n; ++i)
  172.          a[i] = int(a[i] * 1ll * nrev % mod);
  173.    }
  174. }
Add Comment
Please, Sign In to add comment