Guest User

Untitled

a guest
Mar 17th, 2024
108
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 5.32 KB | None | 0 0
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3. using ll = long long;
  4. using pii = pair <int, int>;
  5. vector <int> a;
  6. template <typename t> struct neg_array {
  7.     int low, high;
  8.     vector<t> content;
  9.     neg_array(int _low, int _high) : low(_low), high(_high), content(high - low + 1) {
  10.         assert(high >= low);
  11.     }
  12.     t&operator[](int x) {
  13.         assert(x >= low && x <= high);
  14.         return content[x - low];
  15.     }
  16.     const t&operator[](int x) const {
  17.         assert(x >= low && x <= high);
  18.         return content[x - low];
  19.     }
  20. };
  21. using block = neg_array<int>;
  22. pii calc_range(vector <int> coefs) {
  23.     int low = 0, high = 0;
  24.     int mi = min(0, *min_element(a.begin(), a.end()));
  25.     int MA = max(0, *max_element(a.begin(), a.end()));
  26.     for (int c : coefs) {
  27.         if (c > 0) {
  28.             low += c * mi;
  29.             high += c * MA;
  30.         }
  31.         else {
  32.             low += c * MA;
  33.             high += c * mi;
  34.         }
  35.     }
  36.     return {low, high};
  37. }
  38. ll calc_matches(const block &a, const block &b, bool inverted) {
  39.     int sign_b = inverted ? 1 : -1;
  40.     ll ans = 0;
  41.     int l = a.low;
  42.     int h = a.high;
  43.     if (!inverted) {
  44.         l = max(l, -b.high);
  45.         h = min(h, -b.low);
  46.     }
  47.     else {
  48.         l = max(l, b.low);
  49.         h = min(h, b.high);
  50.     }
  51.     for (int i = l; i <= h; ++i) ans += a[i] * 1ll * b[i * sign_b];
  52.     return ans;
  53. }
  54. const int LESS = 0, LEQ = 1, EQ = 2;
  55. block make(int si, int sj, int sk, vector <tuple <int, int, int> > comparisons) {
  56.     int n = (int)a.size();
  57.     auto [low, high] = calc_range({si, sj, sk});
  58.     block b(low, high);
  59.     for (int i = 0; i < n; ++i) {
  60.         for (int j = 0; j < n; ++j) {
  61.             for (int k = 0; k < n; ++k) {
  62.                 int indices[] = {i, j, k};
  63.                 bool good = true;
  64.                 for (auto [x, dir, y] : comparisons) {
  65.                     if (dir == LESS) good = good && indices[x] < indices[y];
  66.                     if (dir == LEQ) good = good && indices[x] <= indices[y];
  67.                     if (dir == EQ) good = good && indices[x] == indices[y];
  68.                 }
  69.                 if (good) {
  70.                     b[a[i] * si + a[j] * sj + a[k] * sk]++;
  71.                 }
  72.             }
  73.         }
  74.     }
  75.     return b;
  76. }
  77. ll ans = 0;
  78. ll safe_div(ll x, ll y) {
  79.     assert(x % y == 0);
  80.     return x / y;
  81. }
  82. block make_ranges(bool with_empty, int scale = 1) {
  83.     auto [low, high] = calc_range({scale, -scale});
  84.     block b(low, high);
  85.     int n = (int) a.size();
  86.     for (int i = 0; i < n; ++i) for (int j = (with_empty ? i : i + 1); j < n; ++j) b[scale*(a[j] - a[i])]++;
  87.     return b;
  88. }
  89. template <class fun> void assert_val(ll should, int line, fun Valid) {
  90.     vector <tuple<int, int, int, int, int, int> > found;
  91.     const int n = (int)a.size();
  92. #define REP(x) for (int x = 0; x < n; ++x)
  93.     REP(lx) REP(rx) REP(ly) REP(ry) REP(lz) REP(rz) {
  94.         if (a[lx] + a[ly] + a[lz] - a[rx] - a[ry] - a[rz] == 0 && Valid(lx, rx, ly, ry, lz, rz)) {
  95.             found.emplace_back(lx, rx, ly, ry, lz, rz);
  96.         }
  97.     }
  98.     if (should != (ll) found.size()) {
  99.         cerr << "Found: " << found.size() << "\n";
  100.         for (auto [lx, rx, ly, ry, lz, rz] : found) {
  101.             #define make_var(id) "(a[" #id " = " << id << "] = " << a[id] << ")"
  102.             cerr << make_var(lx) " - " make_var(rx) " + " make_var(ly) " - " make_var(ry) " + " make_var(lz) " - " make_var(rz) << "\n";
  103.         }
  104.         cerr << "Expected " << should << " (at line " << line << ")" << endl;
  105.         abort();
  106.     }
  107.     cerr << "OK: Found: " << found.size() << " (at line " << line << ")" << endl;
  108. }
  109. #ifdef LOCAL
  110. #define run_test(val, cond) assert_val(val, __LINE__, [](int lx, int rx, int ly, int ry, int lz, int rz){return (cond);});
  111. #else
  112. #define run_test(...)
  113. #endif
  114. ll step_1() {
  115.     auto full_block = make(1, 1, 1, {});
  116.     ll full = calc_matches(full_block, full_block, true); //counts all choices of lx, ly, lz, rx, ry, rz s.t. a[lx] - a[rx] + a[ly] - a[ry] + a[lz] - a[rz] = 0 with absolutely no regard to any order on uniqueness
  117.     return full;
  118. }
  119. ll step_2() {
  120.     auto invalid_ranges_lx_rx = make(1, -1, 1, {{0, LEQ, 1}}); //Options such that lx ≤ rx, and lz can be anything
  121.     auto invalid_ranges_ly_ry = make(1, -1, -1, {{1, LESS, 0}}); //Options such that ly > ry, and rz can be anything
  122.  
  123.     ll misordered = calc_matches(invalid_ranges_lx_rx, invalid_ranges_ly_ry, false); //All solutions such that lx ≤ rx and ly > ry
  124.     return misordered;
  125. }
  126. int main() {
  127.     int n;
  128.     cin >> n;
  129.     a.resize(n + 1);
  130.     for (int i = 1; i <= n; ++i) cin >> a[i];
  131.     for (int i = 1; i <= n; ++i) a[i] += a[i - 1];
  132.  
  133.     ll full = step_1();
  134.  
  135.     run_test(full, true);
  136.  
  137.     ll misordered = step_2();
  138.  
  139.     full -= 3 * misordered; //Subtracting 3 times excludes (lx ≤ rx && ly > ry), (ly ≤ ry && lz > rz) and (lz ≤ rz && lx > rx)
  140.     //Leaving only (lx > rx && ly > ry && lz > rz) or (lx ≤ rx && ly ≤ ry && lz ≤ rz)
  141.  
  142.     run_test(full, (lx > rx && ly > ry && lz > rz) or (lx <= rx && ly <= ry && lz <= rz));
  143.  
  144.     ll real_degenerate_any = calc_matches(make_ranges(true), make_ranges(false), false) * (n + 1ll);
  145.  
  146.     full -= 3 * real_degenerate_any;
  147.    
  148.     run_test(full, (lx > rx && ly > ry && lz > rz) or (lx < rx && ly < ry && lz < rz) or (lx == rx && ly == ry && lz == rz));
  149.  
  150.     full -= (n + 1) * 1ll * (n + 1) * 1ll * (n + 1);
  151.  
  152.     run_test(full, (lx > rx && ly > ry && lz > rz) or (lx < rx && ly < ry && lz < rz));
  153.  
  154.     full = safe_div(full, 2);
  155.  
  156.     run_test(full, (lx < rx && ly < ry && lz < rz));
  157.  
  158.     ll pairs = calc_matches(make_ranges(false, 2), make_ranges(false), false);
  159.  
  160.     full -= 3 * pairs;
  161.  
  162.     ll zero_ranges = make_ranges(false)[0];
  163.  
  164.     full += 2 * zero_ranges;
  165.  
  166.     full = safe_div(full, 6);
  167.  
  168.     run_test(full, (lx < rx && ly < ry && lz < rz) && pii(lx, rx) < pii(ly, ry) && pii(ly, ry) < pii(lz, rz));
  169.    
  170.     cout << full << endl;
  171. }
  172.  
Advertisement
Add Comment
Please, Sign In to add comment