Advertisement
Guest User

FFT

a guest
Dec 8th, 2019
123
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 8.37 KB | None | 0 0
  1. #include <iostream>
  2. #include <vector>
  3. #include <cmath>
  4. #include <algorithm>
  5.  
  6. using namespace std;
  7.  
  8. long long mod(long long x, long long mod)
  9. {
  10.     return ((x % mod) + mod) % mod;
  11. }
  12.  
  13. long long bin_pow_mod(long long a, long long b, long long mod)
  14. {
  15.     if (b == 0)
  16.         return 1;
  17.     if (b == 1)
  18.         return a % mod;
  19.     if (b % 2 == 0)
  20.     {
  21.         long long x = bin_pow_mod(a, b / 2, mod) % mod;
  22.         return (x * x) % mod;
  23.     }
  24.     else
  25.         return (a * bin_pow_mod(a, b - 1, mod) ) % mod;
  26. }
  27.  
  28. vector <long long> to_binary(long long a)
  29. {
  30.     vector <long long> ans;
  31.     while (a > 0)
  32.     {
  33.         ans.push_back(a % 2);
  34.         a /= 2;
  35.     }
  36.     reverse(ans.begin(), ans.end());
  37.     return ans;
  38. }
  39.  
  40. long long rev(long long x, long long k, long long m)
  41. {
  42.     long long ans = 0;
  43.     vector <int> vec(k);
  44.     long long i = 0;
  45.     while (x > 0)
  46.     {
  47.         vec[i] = x % 2;
  48.         x /= 2;
  49.         i++;
  50.     }
  51.     long long d = 0;
  52.     for (i = k - 1; i >= 0; i--)
  53.         ans += vec[i] * bin_pow_mod(2, (d++), m);
  54.     return ans;
  55. }
  56.  
  57. long long inserted_digit(long long digit, long long i, long long l, long long count_of_bits, long long m)
  58. {
  59.     vector<int> vec_i(count_of_bits);
  60.     long long w = 0;
  61.     while (i > 0)
  62.     {
  63.         vec_i[w] = i % 2;
  64.         i /= 2;
  65.         w++;
  66.     }
  67.     vec_i[l] = digit;
  68.     long long d = 0, ans = 0;
  69.     for (long long qq = 0; qq < count_of_bits; qq++)
  70.         ans += vec_i[qq] * bin_pow_mod(2, (d++), m);
  71.     return ans;
  72. }
  73.  
  74. long long extended_euclid(long long a, long long b, long long& x, long long& y)
  75. {
  76.     if (a == 0)
  77.     {
  78.         x = 0;
  79.         y = 1;
  80.         return b;
  81.     }
  82.     long long x1, y1;
  83.     long long d = extended_euclid(b % a, a, x1, y1);
  84.     x = y1 - (b / a) * x1;
  85.     y = x1;
  86.     return d;
  87. }
  88.  
  89. long long reverse_element(long long a, long long m)
  90. {
  91.     long long x, y;
  92.     long long g = extended_euclid(a, m, x, y);
  93.     if (g != 1)
  94.     {
  95.         cout << a << " не имеет обратного элемента по модулю " << m << endl;
  96.         exit(0);
  97.     }
  98.     else
  99.     {
  100.         x = mod(x, m);
  101.         return x;
  102.     }
  103. }
  104.  
  105. vector <long long> FFT(long long n, long long k, vector<long long> a, long long w, long long m)
  106. {
  107.     vector <long long> b(n), s(n), r(n);
  108.     // 1
  109.     for (long long i = 0; i < n; i++)
  110.         r[i] = a[i];
  111.     // 2
  112.     for (long long l = k - 1; l >= 0; l--)
  113.     {
  114.         for (long long i = 0; i < n; i++)
  115.             s[i] = r[i];
  116.         for (long long i = 0; i < n; i++)
  117.         {
  118.             long long s0 = inserted_digit(0, i, l, k, m);
  119.             long long revi = rev(i / bin_pow_mod(2, l, m), k, m);
  120.             long long s1 = inserted_digit(1, i, l, k, m);
  121.             r[i] = s[s0] + bin_pow_mod(w, revi, m) * s[s1];
  122.             r[i] = mod(r[i], m);
  123.         }
  124.     }
  125.     // 3
  126.     for (long long i = 0; i < n; i++)
  127.     {
  128.         b[i] = r[rev(i, k, m)];
  129.     }
  130.     return b;
  131. }
  132.  
  133. vector <long long> reverse_FFT(long long n, long long k, vector<long long> a, long long w_1, long long m, long long n_1)
  134. {
  135.     vector <long long> b(n), s(n), r(n);
  136.     // 1
  137.     for (long long i = 0; i < n; i++)
  138.         r[i] = a[i];
  139.     // 2
  140.     for (long long l = k - 1; l >= 0; l--)
  141.     {
  142.         for (long long i = 0; i < n; i++)
  143.             s[i] = r[i];
  144.         for (long long i = 0; i < n; i++)
  145.         {
  146.             long long s0 = inserted_digit(0, i, l, k, m);
  147.             long long revi = rev(i / bin_pow_mod(2, l, m), k, m);
  148.             long long s1 = inserted_digit(1, i, l, k, m);
  149.             r[i] = s[s0] + bin_pow_mod(w_1, revi, m) * s[s1];
  150.             r[i] = mod(r[i], m);
  151.         }
  152.     }
  153.     // 3
  154.     for (long long i = 0; i < n; i++)
  155.     {
  156.         b[i] = n_1 * r[rev(i, k, m)];
  157.         b[i] = mod(b[i], m);
  158.     }
  159.     return b;
  160. }
  161.  
  162. vector <long long> divide(vector<long long> v, long long N, long long L)
  163. {
  164.     vector <long long> ans;
  165.     reverse(v.begin(), v.end());
  166.     while (v.size() < N)
  167.         v.push_back(0);
  168.     long long sum = 0, d = 1;
  169.     for (int i = 0; i < v.size(); i++)
  170.     {
  171.         sum += v[i] * d;
  172.         d *= 2;
  173.         if (i % L == L - 1)
  174.         {
  175.             ans.push_back(sum);
  176.             sum = 0;
  177.             d = 1;
  178.         }
  179.     }
  180.     if (sum > 0)
  181.         ans.push_back(sum);
  182.     return ans;
  183. }
  184.  
  185. long long SS(vector <long long> u, vector <long long> v, long long n, long long N)
  186. {
  187.     // 1
  188.     long long l = n/2, k = n - l, K = bin_pow_mod(2, k, LONG_MAX), L = bin_pow_mod(2, l, LONG_MAX);
  189.     u = divide(u, N, L);
  190.     v = divide(v, N, L);
  191.     // 2
  192.     vector <long long> W(u.size()), W_(u.size());
  193.     for (long long i = 0; i < K; i++)
  194.     {
  195.         long long sum1 = 0, sum2 = 0;
  196.         for (long long j = 0; j <= i; j++)
  197.             sum1 += u[i - j] * v[j];
  198.         for (long long j = i + 1; j < K; j++)
  199.             sum2 += u[i + K - j] * v[j];
  200.         W[i] = sum1 - sum2;
  201.         W_[i] = mod(W[i], K);
  202.     }
  203.     // 3
  204.     vector <long long> W__(u.size());
  205.     long long psi = bin_pow_mod(2, 2*L/K, LONG_MAX);
  206.     vector <long long> u_(u.size()), v_(u.size());
  207.     long long pr = 1;
  208.     for (long long i = 0; i < u.size(); i++)
  209.     {
  210.         u_[i] = u[i] * pr;
  211.         v_[i] = v[i] * pr;
  212.         if (i != u.size() - 1)
  213.             pr *= psi;
  214.     }
  215.     long long m = bin_pow_mod(2, 2 * L, LONG_MAX) + 1;
  216.     long long w = bin_pow_mod(2, 4 * L/K, LONG_MAX);
  217.  
  218.     vector <long long> u__ = FFT(u.size(), log(u.size() * 1.0) / log(2.0), u_, w, m);
  219.     vector <long long> v__ = FFT(v.size(), log(v.size() * 1.0) / log(2.0), v_, w, m);
  220.  
  221.     vector <long long> c(u.size());
  222.     for (long long i = 0; i < u.size(); i++)
  223.         c[i] = mod(u__[i] * v__[i], m);
  224.  
  225.     long long w_1 = -1 * bin_pow_mod(2, 2 * L - 4 * L / K, LONG_MAX);
  226.     vector <long long> d = reverse_FFT(c.size(), log(c.size() * 1.0) / log(2.0), c, mod(w_1, m), m, reverse_element(c.size(), m));
  227.  
  228.     long long psi_1 = -1 * bin_pow_mod(2, 2 * L - 2 * L / K, LONG_MAX);
  229.     pr = 1;
  230.     for (long long i = 0; i < d.size(); i++)
  231.     {
  232.         W__[i] = mod(d[i] * pr, m);
  233.         pr *= psi_1;
  234.     }
  235.     // 4
  236.     for (long long i = 0; i < K; i++)
  237.     {
  238.         long long W___ = (bin_pow_mod(2, 2 * L, LONG_MAX) + 1) * (mod(W_[i] - W__[i], K)) + W__[i];
  239.         if (W___ < (i+1) * bin_pow_mod(2, 2 * L, LONG_MAX))
  240.             W[i] = W___;
  241.         else
  242.             W[i] = W___ - K * (bin_pow_mod(2, 2 * L, LONG_MAX) + 1);
  243.     }
  244.     // 5
  245.     long long y = 0;
  246.     for (long long i = 0; i < W.size(); i++)
  247.         y += W[i] * bin_pow_mod(2, L * i, LONG_MAX);
  248.     return y;
  249. }
  250.  
  251. int main()
  252. {
  253.     setlocale(LC_ALL, "RUSSIAN");
  254.     int mode;
  255.     cout << "Введите 1 или 2: 1 - для запуска быстрого преобразования Фурье и обратного быстрого преобразования Фурье, 2 - для запуска алгоритма Шенхаге-Штрассена: ";
  256.     cin >> mode;
  257.     if (mode == 1)
  258.     {
  259.         long long n;
  260.         cout << "Введите n: ";
  261.         cin >> n;
  262.        
  263.         long long n1 = n, count = 0;
  264.         while (n1 != 1)
  265.         {
  266.             count++;
  267.             n1 /= 2;
  268.         }
  269.         long long k = count;
  270.        
  271.         vector <long long> a(n), b(n);
  272.         cout << "Введите элементы массива а: ";
  273.         for (long long i = 0; i < n; i++)
  274.             cin >> a[i];
  275.        
  276.         long long w = 2;
  277.         cout << "Преобразованный массив: ";
  278.         long long m = bin_pow_mod(w, n/2, LONG_MAX) + 1;
  279.         b = FFT(n, k, a, w, m);
  280.         for (long long i = 0; i < n; i++)
  281.             cout << b[i] << " ";
  282.         cout << endl;
  283.        
  284.         cout << "Обратно преобразованный массив: ";
  285.         a = reverse_FFT(n, k, b, reverse_element(w, m), m, reverse_element(n, m));
  286.         for (long long i = 0; i < n; i++)
  287.             cout << a[i] << " ";
  288.         cout << endl;
  289.     }
  290.     else
  291.         if (mode == 2)
  292.         {
  293.             long long a, b;
  294.             cout << "Введите два числа: ";
  295.             cin >> a >> b;
  296.  
  297.             vector <long long> aa = to_binary(a), bb = to_binary(b);
  298.             long long n = 7;
  299.             long long y = SS(aa, bb, n, bin_pow_mod(2, n, LONG_MAX));
  300.             cout << "Результат умножения чисел: " << y << endl;
  301.         }
  302.    
  303.     return 0;
  304. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement