Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <bits/stdc++.h>
// Competitive-programming shorthands. Macro arguments are parenthesized in the
// expansions so that arguments such as `*p` or `a[i]` expand correctly.

// all(x): the begin/end iterator pair of x, for passing a whole container to an algorithm.
#define all(x) begin(x), end(x)
// make_unique(x): sorts x (the sort is skipped when x is already sorted), then erases
// adjacent duplicates, leaving the distinct values in ascending order. The expansion is
// one braced statement, so an unbraced `if (cond) make_unique(v);` guards the erase too.
#define make_unique(x) { if (!is_sorted(all(x))) sort(all(x)); (x).erase(unique(all(x)), (x).end()); }
// remax(x, y): if x > y, assigns y to x and yields true; otherwise yields false.
// Net effect: x becomes min(x, y).
// NOTE(review): the name suggests the opposite effect; confirm this is intended.
#define remax(x, y) ((x) > (y) ? (x) = (y), true : false)
// remin(x, y): if x < y, assigns y to x and yields true; otherwise yields false.
// Net effect: x becomes max(x, y).
// NOTE(review): as with remax, name and effect look swapped; confirm this is intended.
#define remin(x, y) ((x) < (y) ? (x) = (y), true : false)
// size(x): x.size() as a signed int. Because this is a function-like macro, any later
// `obj.size()` call would itself be macro-expanded and fail to compile; write size(obj).
#define size(x) int((x).size())
using namespace std;
using ll = long long;
// Suffix array of a string plus the LCP array of adjacent suffixes.
//
// The constructor appends a '\0' sentinel to `str` and sorts the cyclic rotations of
// that text by prefix doubling (two counting-sort passes per round). While the sentinel
// is the unique smallest byte, rotation order equals suffix order.
// NOTE(review): assumes `str` contains no '\0' bytes. Otherwise the sentinel is not
// unique, arr[0] may differ from n - 1, and the LCP pass can read arr[-1].
//
// Members after construction (n == str.length() + 1):
//   s   - str followed by the '\0' sentinel.
//   arr - arr[k] is the start of the k-th smallest suffix of s.
//   lcp - lcp[k] is the common-prefix length of suffixes arr[k] and arr[k + 1], never
//         counting the sentinel; it has n - 1 entries.
// Bytes are ranked as unsigned char, the same order std::string comparison uses.
struct SuffixArray {
    const std::string s;
    const int n;
    std::vector<int> arr;
    std::vector<int> lcp;

    SuffixArray(const std::string& str)
        : s(str + char(0))
        , n(static_cast<int>(s.length()))  // length(): this file defines a size(x) macro
        , arr(n)
        , lcp(n - 1)
    {
        // Cyclic index reduction for values in [0, 2n).
        const auto wrap = [this](int x) { return x >= n ? x - n : x; };

        // Round 0: counting sort by the first byte. The byte is read as unsigned char so
        // values >= 0x80 never become negative indices where char is signed.
        std::vector<int> bucket(256, 0);
        for (int i = 0; i < n; i++) { bucket[static_cast<unsigned char>(s[i])]++; }
        for (int b = 1; b < 256; b++) { bucket[b] += bucket[b - 1]; }
        for (int i = n - 1; i >= 0; i--) { arr[--bucket[static_cast<unsigned char>(s[i])]] = i; }

        // c[i] = rank class of the current-length prefix starting at i.
        std::vector<int> c(n), buff(n), ind(n), cnt(n);
        c[arr[0]] = 0;
        for (int i = 1; i < n; i++) {
            c[arr[i]] = c[arr[i - 1]] + (s[arr[i]] != s[arr[i - 1]] ? 1 : 0);
        }

        // Each round doubles the compared prefix length from len to 2 * len.
        for (int len = 1; len < n; len <<= 1) {
            // arr is sorted by the first len chars. Stepping every start back by len gives
            // the order by the second half, so a stable counting sort keyed on the first
            // half's class sorts by the (first half, second half) pair.
            for (int i = 0; i < n; i++) { arr[i] = wrap(arr[i] - len + n); }
            std::fill(cnt.begin(), cnt.end(), 0);
            for (int i = 0; i < n; i++) { cnt[c[arr[i]]]++; }
            ind[0] = 0;
            for (int k = 1; k < n; k++) { ind[k] = ind[k - 1] + cnt[k - 1]; }
            for (int i = 0; i < n; i++) { buff[ind[c[arr[i]]]++] = arr[i]; }
            arr.swap(buff);

            // Re-rank: a new class starts wherever either half differs from the predecessor.
            buff[arr[0]] = 0;
            for (int i = 1; i < n; i++) {
                const int cur = arr[i];
                const int prev = arr[i - 1];
                const bool differs = c[cur] != c[prev] || c[wrap(cur + len)] != c[wrap(prev + len)];
                buff[cur] = buff[prev] + (differs ? 1 : 0);
            }
            c.swap(buff);
            // All n ranks distinct: the order is final and further rounds cannot change it.
            if (c[arr[n - 1]] == n - 1) { break; }
        }

        // Kasai's algorithm, pairing each suffix with its predecessor in sorted order:
        // the match length drops by at most one from suffix i to suffix i + 1.
        std::vector<int> pos(n);
        for (int k = 0; k < n; k++) { pos[arr[k]] = k; }
        int pre = 0;
        for (int i = 0; i < n - 1; i++) {
            const int p = pos[i];
            const int j = arr[p - 1];
            // Stop before the sentinel of either suffix.
            const int limit = std::min(n - 1 - i, n - 1 - j);
            int k = std::max(0, pre - 1);
            while (k < limit && s[i + k] == s[j + k]) { k++; }
            lcp[p - 1] = k;
            pre = k;
        }
    }
};
// Sparse table answering range-minimum queries by *position* over a fixed array.
// operator()(l, r) returns the index of the minimum of arr[l..r] (inclusive), with ties
// going to the leftmost index, or `def` when l > r.
// Template parameters: T is the index type stored in the table. The constructor copies
// `arr` into a std::vector<int>, so only T = int is usable as written. `def` is returned
// for empty ranges and fills table cells that no query reads.
template<typename T = int, T def = 0>
struct SparseTable {
    int n, p;                       // n = element count, p = number of levels in t
    std::vector<std::vector<T>> t;  // t[d][i] = argmin of arr[i .. i + 2^d - 1] when i + 2^d <= n
    std::vector<int> lg2;           // lg2[k] = floor(log2(k)) for k >= 1; lg2[0] = -1
    std::vector<int> buff;          // copy of the input values

    // Picks whichever of indices a, b holds the smaller value (a on ties). An index b at
    // or past the end (e.g. a large `def`) is treated as absent and a is returned.
    inline T Func(const T& __restrict a, const T& __restrict b) const {
        if (b >= n || buff[a] <= buff[b])
            return a;
        return b;
    }

    // Number of table levels, floor(log2(m + 1)) + 1, computed without floating point.
    static int LevelCount(int m) {
        int levels = 0;
        for (int v = m + 1; v > 0; v >>= 1) { levels++; }
        return levels;
    }

    SparseTable(const std::vector<T>& arr)
        : n(static_cast<int>(arr.end() - arr.begin()))  // iterators: this file defines a size(x) macro
        , p(LevelCount(n))
        , t(p, std::vector<T>(n, def))
        , lg2(n + 1, -1)
        , buff(arr)
    {
        std::iota(t[0].begin(), t[0].end(), 0);
        for (int d = 1; d < p; d++) {
            const int half = 1 << (d - 1);
            // Build only windows lying fully inside the array: both children are then
            // real, already-built cells, so `def` never reaches Func during construction.
            for (int i = 0; i + 2 * half <= n; i++) {
                t[d][i] = Func(t[d - 1][i], t[d - 1][i + half]);
            }
        }
        for (int k = 1; k <= n; k++) { lg2[k] = lg2[k >> 1] + 1; }
    }

    // Index of the minimum of arr[l..r] (leftmost on ties), or `def` when l > r.
    // Expects 0 <= l and r < n whenever l <= r.
    T operator () (int l, int r) const {
        if (l > r) return def;
        const int d = lg2[r - l + 1];
        return Func(t[d][l], t[d][r - (1 << d) + 1]);
    }
};
- ll Solve(const SparseTable<int, 1 << 25>& __restrict t, const vector<int>& __restrict arr, int l, int r, int val) {
- vector<tuple<int,int,int>> queries;
- queries.push_back({l, r, val});
- ll ans = 0;
- for (int i = 0; i < size(queries); i++) {
- const auto& [cl, cr, cval] = queries[i];
- l = cl;
- r = cr;
- val = cval;
- while (l <= r) {
- while (l <= r && arr[l] < val) { l++; }
- if (l > r) break;
- int cur = t(l, r);
- if (arr[cur] < val) {
- cur--;
- } else {
- cur = r;
- }
- int len = cur - l + 1;
- ans += 1ll * len * (len + 1) / 2;
- queries.push_back({l, cur, val + 1});
- l = cur + 1;
- }
- }
- return ans;
- }
- int main() {
- ios_base::sync_with_stdio(false);
- cin.tie(nullptr);
- cout.tie(nullptr);
- string s;
- cin >> s;
- SuffixArray suff(s);
- const int n = size(s);
- const auto& arr = suff.lcp;
- SparseTable<int, 1 << 25> t(arr);
- cout << Solve(t, arr, 0, n - 1, 0) << '\n';
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement