Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <bits/stdc++.h>
- using namespace std;
using ll = long long;             // 64-bit type used for all tuple counts
using pii = std::pair<int, int>;  // (first, second) pair, compared lexicographically
std::vector<int> a;               // input values; main() stores prefix sums a[0..n] here (a[0] == 0)
/// Dense array indexed by any int in the closed range [low, high]
/// (negative indices allowed). Element x is stored at content[x - low].
/// All elements start value-initialised (0 for arithmetic t).
template <typename t> struct neg_array {
    int low, high;
    std::vector<t> content;

    /// Number of slots needed for [lo, hi].
    /// Runs from the member initialiser, i.e. *before* `content` allocates, so
    /// a reversed range trips the assert instead of making std::vector throw
    /// length_error/bad_alloc on a negative size converted to size_t.
    /// The size is computed in 64 bits so hi - lo + 1 cannot overflow int.
    static std::size_t checked_size(int lo, int hi) {
        assert(hi >= lo);
        return static_cast<std::size_t>(static_cast<long long>(hi) - lo + 1);
    }

    /// Precondition: _high >= _low (asserted before allocation).
    neg_array(int _low, int _high)
        : low(_low), high(_high), content(checked_size(_low, _high)) {}

    /// Element access; the range check is an assert (disabled under NDEBUG).
    t &operator[](int x) {
        assert(x >= low && x <= high);
        return content[x - low];
    }
    const t &operator[](int x) const {
        assert(x >= low && x <= high);
        return content[x - low];
    }
};
using block = neg_array<int>;
- pii calc_range(vector <int> coefs) {
- int low = 0, high = 0;
- int mi = min(0, *min_element(a.begin(), a.end()));
- int MA = max(0, *max_element(a.begin(), a.end()));
- for (int c : coefs) {
- if (c > 0) {
- low += c * mi;
- high += c * MA;
- }
- else {
- low += c * MA;
- high += c * mi;
- }
- }
- return {low, high};
- }
- ll calc_matches(const block &a, const block &b, bool inverted) {
- int sign_b = inverted ? 1 : -1;
- ll ans = 0;
- int l = a.low;
- int h = a.high;
- if (!inverted) {
- l = max(l, -b.high);
- h = min(h, -b.low);
- }
- else {
- l = max(l, b.low);
- h = min(h, b.high);
- }
- for (int i = l; i <= h; ++i) ans += a[i] * 1ll * b[i * sign_b];
- return ans;
- }
// Comparison kinds used in make()'s (x, dir, y) index constraints.
constexpr int LESS = 0;  // indices[x] <  indices[y]
constexpr int LEQ = 1;   // indices[x] <= indices[y]
constexpr int EQ = 2;    // indices[x] == indices[y]
- block make(int si, int sj, int sk, vector <tuple <int, int, int> > comparisons) {
- int n = (int)a.size();
- auto [low, high] = calc_range({si, sj, sk});
- block b(low, high);
- for (int i = 0; i < n; ++i) {
- for (int j = 0; j < n; ++j) {
- for (int k = 0; k < n; ++k) {
- int indices[] = {i, j, k};
- bool good = true;
- for (auto [x, dir, y] : comparisons) {
- if (dir == LESS) good = good && indices[x] < indices[y];
- if (dir == LEQ) good = good && indices[x] <= indices[y];
- if (dir == EQ) good = good && indices[x] == indices[y];
- }
- if (good) {
- b[a[i] * si + a[j] * sj + a[k] * sk]++;
- }
- }
- }
- }
- return b;
- }
// NOTE(review): this global is not referenced anywhere else in this file; main() prints its own local count.
long long ans{0};
/// Exact integer division: asserts that y divides x, then returns x / y.
/// Used to check that the inclusion-exclusion totals are really multiples of
/// the symmetry factors they are divided by.
long long safe_div(long long x, long long y) {
    assert(x % y == 0);
    const long long quotient = x / y;
    return quotient;
}
- block make_ranges(bool with_empty, int scale = 1) {
- auto [low, high] = calc_range({scale, -scale});
- block b(low, high);
- int n = (int) a.size();
- for (int i = 0; i < n; ++i) for (int j = (with_empty ? i : i + 1); j < n; ++j) b[scale*(a[j] - a[i])]++;
- return b;
- }
- template <class fun> void assert_val(ll should, int line, fun Valid) {
- vector <tuple<int, int, int, int, int, int> > found;
- const int n = (int)a.size();
- #define REP(x) for (int x = 0; x < n; ++x)
- REP(lx) REP(rx) REP(ly) REP(ry) REP(lz) REP(rz) {
- if (a[lx] + a[ly] + a[lz] - a[rx] - a[ry] - a[rz] == 0 && Valid(lx, rx, ly, ry, lz, rz)) {
- found.emplace_back(lx, rx, ly, ry, lz, rz);
- }
- }
- if (should != (ll) found.size()) {
- cerr << "Found: " << found.size() << "\n";
- for (auto [lx, rx, ly, ry, lz, rz] : found) {
- #define make_var(id) "(a[" #id " = " << id << "] = " << a[id] << ")"
- cerr << make_var(lx) " - " make_var(rx) " + " make_var(ly) " - " make_var(ry) " + " make_var(lz) " - " make_var(rz) << "\n";
- }
- cerr << "Expected " << should << " (at line " << line << ")" << endl;
- abort();
- }
- cerr << "OK: Found: " << found.size() << " (at line " << line << ")" << endl;
- }
// run_test(val, cond): in LOCAL builds, calls assert_val to brute-force check
// that `val` equals the number of zero-sum index 6-tuples for which `cond`
// (an expression over lx, rx, ly, ry, lz, rz) holds. In other builds it is a
// no-op. Both variants expand to a single expression with no trailing ';', so
// `run_test(...);` is exactly one statement (safe in an unbraced if/else).
#ifdef LOCAL
#define run_test(val, cond) assert_val(val, __LINE__, [](int lx, int rx, int ly, int ry, int lz, int rz) { return (cond); })
#else
#define run_test(...) ((void)0)
#endif
- ll step_1() {
- auto full_block = make(1, 1, 1, {});
- ll full = calc_matches(full_block, full_block, true); //counts all choices of lx, ly, lz, rx, ry, rz s.t. a[lx] - a[rx] + a[ly] - a[ry] + a[lz] - a[rz] = 0 with absolutely no regard to any order on uniqueness
- return full;
- }
- ll step_2() {
- auto invalid_ranges_lx_rx = make(1, -1, 1, {{0, LEQ, 1}}); //Options such that lx ≤ rx, and lz can be anything
- auto invalid_ranges_ly_ry = make(1, -1, -1, {{1, LESS, 0}}); //Options such that ly > ry, and rz can be anything
- ll misordered = calc_matches(invalid_ranges_lx_rx, invalid_ranges_ly_ry, false); //All solutions such that lx ≤ rx and ly > ry
- return misordered;
- }
- int main() {
- int n;
- cin >> n;
- a.resize(n + 1);
- for (int i = 1; i <= n; ++i) cin >> a[i];
- for (int i = 1; i <= n; ++i) a[i] += a[i - 1];
- ll full = step_1();
- run_test(full, true);
- ll misordered = step_2();
- full -= 3 * misordered; //Subtracting 3 times excludes (lx ≤ rx && ly > ry), (ly ≤ ry && lz > rz) and (lz ≤ rz && lx > rx)
- //Leaving only (lx > rx && ly > ry && lz > rz) or (lx ≤ rx && ly ≤ ry && lz ≤ rz)
- run_test(full, (lx > rx && ly > ry && lz > rz) or (lx <= rx && ly <= ry && lz <= rz));
- ll real_degenerate_any = calc_matches(make_ranges(true), make_ranges(false), false) * (n + 1ll);
- full -= 3 * real_degenerate_any;
- run_test(full, (lx > rx && ly > ry && lz > rz) or (lx < rx && ly < ry && lz < rz) or (lx == rx && ly == ry && lz == rz));
- full -= (n + 1) * 1ll * (n + 1) * 1ll * (n + 1);
- run_test(full, (lx > rx && ly > ry && lz > rz) or (lx < rx && ly < ry && lz < rz));
- full = safe_div(full, 2);
- run_test(full, (lx < rx && ly < ry && lz < rz));
- ll pairs = calc_matches(make_ranges(false, 2), make_ranges(false), false);
- full -= 3 * pairs;
- ll zero_ranges = make_ranges(false)[0];
- full += 2 * zero_ranges;
- full = safe_div(full, 6);
- run_test(full, (lx < rx && ly < ry && lz < rz) && pii(lx, rx) < pii(ly, ry) && pii(ly, ry) < pii(lz, rz));
- cout << full << endl;
- }
Advertisement
Add Comment
Please, Sign In to add comment