Advertisement
TrickmanOff

Karatsuba arrays + debug

Sep 25th, 2019
247
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 5.71 KB | None | 0 0
  1. #pragma optimization_level 3
  2. #pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math,O3")
  3. #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
  4. #include <iostream>
  5. #include <algorithm>
  6. #include <fstream>
  7. #include <vector>
  8. #include <queue>
  9. #include <stack>
  10. #include <functional>
  11. #include <set>
  12. #include <map>
  13. #include <math.h>
  14. #include <cmath>
  15. #include <string>
  16. #include <time.h>
  17. #include <random>
  18. #include <unordered_set>
  19. #include <unordered_map>
  20. #include <bitset>
  21. #include <string.h>
  22. #include <complex>
  23. #include <ctime>
  24. using namespace std;
  25.  
  26. #define fast cin.tie(0);cout.tie(0);cin.sync_with_stdio(0);cout.sync_with_stdio(0);
  27. #define cin in
  28. //#define cout out
  29. #define pii pair<int,int>
  30. //#define ll long long
  31. #define db double
  32. #define ld long double
  33. #define uset unordered_set
  34. #define umap unordered_map
  35. #define vec vector
  36. #define ms multiset
  37. #define pb push_back
  38. #define pll pair<ll,ll>
  39. #define pdd pair<ld, ld>
  40. #define pq priority_queue
  41. #define umap unordered_map
  42. #define uset unordered_set
  43. #define pnn pair<Node*, Node*>
  44. #define uid uniform_int_distribution
  45.  
  46. typedef long long ll;
  47. typedef unsigned int uint;
  48.  
  49. ifstream in("input.txt");
  50. ofstream out("output.txt");
  51.  
  52. typedef unsigned int digit;
  53. const int LIM = 32;
  54.  
  55. struct poly {
  56.     digit* val;
  57.     int len;
  58.  
  59.     digit operator[](int pos) {
  60.         if (pos >= len)
  61.             return 0;
  62.         return val[pos];
  63.     }
  64. };
  65.  
  66. poly sum(poly a, poly b) {
  67.     poly s;
  68.     if (a.len < b.len)
  69.         swap(a, b);
  70.  
  71.     s.len = a.len;
  72.     s.val = new digit[s.len];
  73.  
  74.     for (int i = 0; i < a.len; i++)
  75.         s.val[i] = a[i] + b[i];
  76.  
  77.     return s;
  78. }
  79.  
  80. poly& sub(poly& a, poly b) {
  81.     for (int i = 0; i < a.len; i++)
  82.         a.val[i] -= b[i];
  83.     return a;
  84. }
  85.  
  86. void print(poly& a) {
  87.     for (int i = 0; i < a.len; i++)
  88.         cout << a[i] << ' ';
  89.     cout << "  len: " << a.len;
  90.     cout << '\n';
  91. }
  92.  
  93. bool DEBUG = 0;
  94.  
  95. poly karatsuba(poly a, poly b) {
  96.     poly res;
  97.     int n = a.len;
  98.  
  99.     res.len = 2 * n - 1;
  100.     res.val = new digit[res.len];
  101.  
  102.     if (a.len <= LIM) {
  103.         memset(res.val, 0, sizeof(digit) * res.len);
  104.        
  105.         if (DEBUG) {
  106.             cout << "Naive multiplication\n";
  107.             cout << "First: ";
  108.             print(a);
  109.             cout << "Second: ";
  110.             print(b);
  111.         }
  112.        
  113.         for (int i = 0; i < a.len; i++)
  114.             for (int j = 0; j < b.len; j++) {
  115.                 res.val[i + j] += a[i] * b[j];
  116.                 if (DEBUG) {
  117.                     cout << "res: ";
  118.                     print(res);
  119.                 }
  120.             }
  121.         return res;
  122.     }
  123.  
  124.     poly a_part1;
  125.     a_part1.val = a.val;
  126.     a_part1.len = n / 2;
  127.     if (DEBUG) {
  128.         cout << "a_part1: ";
  129.         print(a_part1);
  130.     }
  131.  
  132.     poly a_part2;
  133.     a_part2.val = a.val + a_part1.len;
  134.     a_part2.len = n - a_part1.len;
  135.     if (DEBUG) {
  136.         cout << "a_part2: ";
  137.         print(a_part2);
  138.     }
  139.  
  140.     poly b_part1;
  141.     b_part1.val = b.val;
  142.     b_part1.len = n / 2;
  143.     if (DEBUG) {
  144.         cout << "b_part1: ";
  145.         print(b_part1);
  146.     }
  147.  
  148.     poly b_part2;
  149.     b_part2.val = b.val + b_part1.len;
  150.     b_part2.len = n - b_part1.len;
  151.     if (DEBUG) {
  152.         cout << "b_part2: ";
  153.         print(b_part2);
  154.     }
  155.  
  156.     poly sum_of_a_parts = sum(a_part1, a_part2);
  157.     if (DEBUG) {
  158.         cout << "sum_of_a_parts: ";
  159.         print(sum_of_a_parts);
  160.     }
  161.     poly sum_of_b_parts = sum(b_part1, b_part2);
  162.     if (DEBUG) {
  163.         cout << "sum_of_b_parts: ";
  164.         print(sum_of_b_parts);
  165.     }
  166.  
  167.     poly product_of_sums_of_parts = karatsuba(sum_of_a_parts, sum_of_b_parts);
  168.     if (DEBUG) {
  169.         cout << "product_of_sums_of_parts: ";
  170.         print(product_of_sums_of_parts);
  171.     }
  172.  
  173.     poly product_of_first_parts = karatsuba(a_part1, b_part1);
  174.     if (DEBUG) {
  175.         cout << "product_of_first_parts: ";
  176.         print(product_of_first_parts);
  177.     }
  178.    
  179.     poly product_of_second_parts = karatsuba(a_part2, b_part2);
  180.     if (DEBUG) {
  181.         cout << "product_of_second_parts: ";
  182.         print(product_of_second_parts);
  183.     }
  184.  
  185.     poly sum_of_middle_terms = sub(sub(product_of_sums_of_parts, product_of_first_parts), product_of_second_parts);
  186.     if (DEBUG) {
  187.         cout << "sum_of_middle_terms: ";
  188.         print(sum_of_middle_terms);
  189.     }
  190.  
  191.     memset(res.val, 0, res.len * sizeof(digit));
  192.  
  193.     memcpy(res.val, product_of_first_parts.val, product_of_first_parts.len * sizeof(digit));
  194.     if (DEBUG) {
  195.         cout << "res: ";
  196.         print(res);
  197.     }
  198.     memcpy(res.val + 2 * (n/2), product_of_second_parts.val, product_of_second_parts.len * sizeof(digit));
  199.     if (DEBUG) {
  200.         cout << "res: ";
  201.         print(res);
  202.     }
  203.  
  204.     for (int i = 0; i < sum_of_middle_terms.len; i++)
  205.         res.val[a_part1.len + i] += sum_of_middle_terms[i];
  206.  
  207.     //
  208.     //зачистка
  209.     //
  210.  
  211.     delete[] sum_of_a_parts.val;
  212.     delete[] sum_of_b_parts.val;
  213.     delete[] product_of_sums_of_parts.val;
  214.     delete[] product_of_first_parts.val;
  215.     delete[] product_of_second_parts.val;
  216.  
  217.     return res;
  218. }
  219.  
  220. poly read_poly(string& s) {
  221.     vector<digit> nums;
  222.  
  223.     digit num = 0;
  224.     for (char x : s) {
  225.         if (x == ' ') {
  226.             nums.push_back(num);
  227.             num = 0;
  228.         }
  229.         else
  230.             num = num * 10 + x - '0';
  231.     }
  232.     nums.push_back(num);
  233.  
  234.     poly cur;
  235.     cur.val = new digit[nums.size()];
  236.     cur.len = nums.size();
  237.     for (int i = 0; i < cur.len; i++)
  238.         cur.val[i] = nums[i];
  239.  
  240.     return cur;
  241. }
  242.  
  243. void format(poly &a, poly &b) {
  244.     poly s;
  245.     if (a.len < b.len)
  246.         swap(a, b);
  247.  
  248.     s.val = new digit[a.len];
  249.     s.len = a.len;
  250.     memset(s.val, 0, s.len * sizeof(digit));
  251.     //print(s);
  252.     memcpy(s.val, b.val, b.len * sizeof(digit));
  253.     //print(s);
  254.     b = s;
  255. }
  256.  
  257. int main()
  258. {
  259.     fast;
  260.     unsigned int start_time = clock();
  261.     string s;
  262.     int fin_n = 0;
  263.     getline(cin, s);
  264.     poly res = read_poly(s);
  265.     fin_n += res.len - 1;
  266.  
  267.     while (getline(cin, s)) {
  268.         poly b = read_poly(s);
  269.         fin_n += b.len - 1;
  270.         format(res, b);
  271.  
  272.         //cout << "first: ";
  273.         //print(res);
  274.         //cout << "second: ";
  275.         //print(b);
  276.  
  277.         res = karatsuba(res, b);
  278.         //cout << "result: ";
  279.         //print(res);
  280.         //cout << '\n';
  281.     }
  282.  
  283.     //for (int i = 0; i <= fin_n; i++)
  284.         //cout << res[i] << ' ';
  285.     unsigned int end_time = clock();
  286.     cout << (end_time - start_time) / 1000.0;
  287. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement