Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#include <assert.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>
using namespace std;
#define sqr(a) ((a)*(a))
#define all(a) (a).begin(), (a).end()
// Prime modulus used by binPow by default.
const long long MOD = 1000000007LL;
const long long MAX_N = 100;

/// Computes a^b modulo `mod` by iterative binary exponentiation.
///
/// @param a   Base. It is reduced into [0, mod) before use, so negative and very
///            large values are accepted. The previous version multiplied the
///            unreduced base, which overflowed long long once |a| exceeded about 9e9
///            and returned negative results for negative a.
/// @param b   Exponent; must be non-negative.
/// @param mod Modulus; must be positive with (mod - 1)^2 fitting in long long
///            (true for the default MOD).
/// @return    a^b mod `mod`, in [0, mod).
long long binPow(long long a, long long b, long long mod = MOD) {
    assert(b >= 0 && mod > 0);
    long long factor = a % mod;
    if (factor < 0) {
        factor += mod;  // C++ '%' keeps the sign of the dividend
    }
    long long result = 1 % mod;  // '1 % mod' yields 0 when mod == 1
    for (; b > 0; b >>= 1) {
        if (b & 1) {
            result = result * factor % mod;
        }
        factor = factor * factor % mod;
    }
    return result;
}
- // fft: complex<double> FFT used to multiply / square integer polynomials (multiply, square).
// Original author's note: "make it understandable one day...". The comments below describe
// what each piece visibly does. The code itself is intentionally left untouched because the
// numeric results depend on the exact order of floating-point operations.
- namespace fft {
- typedef double dbl;
// Minimal complex number: x is the real part, y the imaginary part
// (see operator* and conj below).
- struct num {
- dbl x, y;
- num() { x = y = 0; }
- num(dbl x_, dbl y_) : x(x_), y(y_) {}
- };
- inline num operator+(num a, num b) { return num(a.x + b.x, a.y + b.y); }
- inline num operator-(num a, num b) { return num(a.x - b.x, a.y - b.y); }
- inline num operator*(num a, num b) { return num(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); }
- inline num conj(num a) { return num(a.x, -a.y); }
// Shared lookup tables, grown on demand by ensure_base():
//   base  - log2 of the current table size (tables hold 1 << base entries);
//   roots - roots[k + j] == exp(+i*pi*j/k) for every power of two k < (1 << base)
//           and 0 <= j < k. roots[0] is a placeholder that no code in this namespace reads;
//   rev   - rev[t] is t with its low `base` bits reversed.
- int base = 1;
- vector<num> roots = {{0, 0}, {1, 0}};
- vector<int> rev = {0, 1};
- const dbl PI = static_cast<dbl>(acosl(-1.0));
// Grows `rev` and `roots` so that transforms of size up to 2^nbase can be performed.
// Does nothing when the tables are already at least that large.
- void ensure_base(int nbase) {
- if (nbase <= base) {
- return;
- }
- rev.resize(1 << nbase);
// Bit reversal over nbase bits, built from the already-computed value for i >> 1.
- for (int i = 0; i < (1 << nbase); i++) {
- rev[i] = (rev[i >> 1] >> 1) + ((i & 1) << (nbase - 1));
- }
- roots.resize(1 << nbase);
// Each pass doubles the root table. Even slots copy the previous level's roots; odd slots are
// computed directly with cos/sin. The commented-out lines below show the rejected alternative
// of multiplying by a step root z (presumably avoided to limit rounding-error accumulation).
- while (base < nbase) {
- dbl angle = 2 * PI / (1 << (base + 1));
- // num z(cos(angle), sin(angle));
- for (int i = 1 << (base - 1); i < (1 << base); i++) {
- roots[i << 1] = roots[i];
- // roots[(i << 1) + 1] = roots[i] * z;
- dbl angle_i = angle * (2 * i + 1 - (1 << base));
- roots[(i << 1) + 1] = num(cos(angle_i), sin(angle_i));
- }
- base++;
- }
- }
// In-place iterative radix-2 FFT of the first n entries of `a`. n defaults to a.size() and
// must be a power of two. The twiddles are roots[k + j] = exp(+i*pi*j/k), i.e. a positive-sign
// exponent; callers obtain the inverse transform by feeding index-mirrored, pre-scaled data.
// NOTE(review): n == 0 passes the assert but reaches __builtin_ctz(0), which is undefined.
// Both callers here always pass n >= 1.
- void fft(vector<num>& a, int n = -1) {
- if (n == -1) {
- n = (int) a.size();
- }
- assert((n & (n - 1)) == 0);
- int zeros = __builtin_ctz(n);
- ensure_base(zeros);
- int shift = base - zeros;
// Bit-reversal permutation over log2(n) bits; each pair is swapped once (i < partner).
- for (int i = 0; i < n; i++) {
- if (i < (rev[i] >> shift)) {
- swap(a[i], a[rev[i] >> shift]);
- }
- }
// Butterflies: k is half the current block length.
- for (int k = 1; k < n; k <<= 1) {
- for (int i = 0; i < n; i += 2 * k) {
- for (int j = 0; j < k; j++) {
- num z = a[i + j + k] * roots[j + k];
- a[i + j + k] = a[i + j] - z;
- a[i + j] = a[i + j] + z;
- }
- }
- }
- }
// Scratch buffer reused across calls; it is grown but never shrunk.
// fb is not referenced anywhere in this namespace.
- vector<num> fa, fb;
// Returns the coefficients of a(x)^2 (length 2*|a| - 1), each rounded with llround.
// Even-indexed coefficients go into real parts and odd-indexed ones into imaginary parts,
// so a transform of half the padded length suffices. The result is unpacked the same way:
// even index from .x, odd index from .y.
// NOTE(review): exact only while double rounding error stays below 0.5. No bound is
// established here.
- vector<int64_t> square(const vector<int>& a) {
- if (a.empty()) {
- return {};
- }
- int need = (int) a.size() + (int) a.size() - 1;
- int nbase = 1;
- while ((1 << nbase) < need) nbase++;
- ensure_base(nbase);
- int sz = 1 << nbase;
- if ((sz >> 1) > (int) fa.size()) {
- fa.resize(sz >> 1);
- }
- for (int i = 0; i < (sz >> 1); i++) {
- int x = (2 * i < (int) a.size() ? a[2 * i] : 0);
- int y = (2 * i + 1 < (int) a.size() ? a[2 * i + 1] : 0);
- fa[i] = num(x, y);
- }
- fft(fa, sz >> 1);
// The loop handles each mirrored pair (i, j = -i mod sz/2) once:
//   fe / fo are the spectra of the even / odd coefficients, recovered via conjugate symmetry;
//   aux is the square's even-part spectrum and 2*tmp its odd-part spectrum.
// Each product is written at the mirrored index and scaled by r = 1/(sz/2), so the second
// forward fft() below acts as the inverse transform.
- num r(1.0 / (sz >> 1), 0.0);
- for (int i = 0; i <= (sz >> 2); i++) {
- int j = ((sz >> 1) - i) & ((sz >> 1) - 1);
- num fe = (fa[i] + conj(fa[j])) * num(0.5, 0);
- num fo = (fa[i] - conj(fa[j])) * num(0, -0.5);
- num aux = fe * fe + fo * fo * roots[(sz >> 1) + i] * roots[(sz >> 1) + i];
- num tmp = fe * fo;
- fa[i] = r * (conj(aux) + num(0, 2) * conj(tmp));
- fa[j] = r * (aux + num(0, 2) * tmp);
- }
- fft(fa, sz >> 1);
- vector<int64_t> res(need);
- for (int i = 0; i < need; i++) {
- res[i] = llround(i % 2 == 0 ? fa[i >> 1].x : fa[i >> 1].y);
- }
- return res;
- }
// Returns the coefficients of a(x) * b(x) (length |a| + |b| - 1), each rounded with llround.
// Delegates to square() when the inputs are equal (an O(n) comparison).
// Otherwise a is packed into real parts and b into imaginary parts of one array of size sz,
// and that array is transformed once.
// NOTE(review): same double-precision caveat as square().
- vector<int64_t> multiply(const vector<int>& a, const vector<int>& b) {
- if (a.empty() || b.empty()) {
- return {};
- }
- if (a == b) {
- return square(a);
- }
- int need = (int) a.size() + (int) b.size() - 1;
- int nbase = 1;
- while ((1 << nbase) < need) nbase++;
- ensure_base(nbase);
- int sz = 1 << nbase;
- if (sz > (int) fa.size()) {
- fa.resize(sz);
- }
- for (int i = 0; i < sz; i++) {
- int x = (i < (int) a.size() ? a[i] : 0);
- int y = (i < (int) b.size() ? b[i] : 0);
- fa[i] = num(x, y);
- }
- fft(fa, sz);
// For the packed spectrum P = A + i*B: P(j)^2 - conj(P(i))^2 == 4i * A(j) * B(j) when the
// inputs are real. Multiplying by r = -i * 0.25 / (sz/2) leaves A(j)*B(j)/(sz/2) stored at
// the mirrored index i (and symmetrically at j).
- num r(0, -0.25 / (sz >> 1));
- for (int i = 0; i <= (sz >> 1); i++) {
- int j = (sz - i) & (sz - 1);
- num z = (fa[j] * fa[j] - conj(fa[i] * fa[i])) * r;
- fa[j] = (fa[i] * fa[i] - conj(fa[j] * fa[j])) * r;
- fa[i] = z;
- }
// Fold the size-sz spectrum into sz/2 entries that pack even (A0) and odd (A1) coefficients
// as real + imaginary parts, so a half-length fft() completes the inverse transform.
- for (int i = 0; i < (sz >> 1); i++) {
- num A0 = (fa[i] + fa[i + (sz >> 1)]) * num(0.5, 0);
- num A1 = (fa[i] - fa[i + (sz >> 1)]) * num(0.5, 0) * roots[(sz >> 1) + i];
- fa[i] = A0 + A1 * num(0, 1);
- }
- fft(fa, sz >> 1);
// Unpack: even result index from .x, odd result index from .y.
- vector<int64_t> res(need);
- for (int i = 0; i < need; i++) {
- res[i] = llround(i % 2 == 0 ? fa[i >> 1].x : fa[i >> 1].y);
- }
- return res;
- }
- }
/// Maps a character of the text / pattern to its weight in the FFT correlation.
///
/// @param c One of '?', 'o', 'k'.
/// @return  0 for '?', 1 for 'o', -1 for 'k'.
/// @throws std::invalid_argument for any other character. The previous version used a bare
///         `throw;`, which calls std::terminate when no exception is being handled.
int eval(char c) {
    switch (c) {
        case '?':
            return 0;
        case 'o':
            return 1;
        case 'k':
            return -1;
        default:
            throw std::invalid_argument(std::string("eval: unexpected character '") + c + "'");
    }
}
// Capacity of the global per-position arrays below. main() indexes cnto/cntk as
// cnto[i + 1] for i < n and reads dp[n], so the capacity must be strictly greater than the
// input length. The previous value 1e6 overflowed by one element when n == 1e6.
const int maxn = 1000000 + 5;
#define ll long long
ll dp[maxn];              // dp values filled by main(), indexed by prefix length
int val[maxn];            // per-window value computed in main() from the FFT correlation
int cnto[maxn], cntk[maxn];  // prefix counts filled by main()

/// Raises x to y when y is larger (in-place maximum).
void upmax(ll& x, ll y) {
    if (y > x) {
        x = y;
    }
}
- int main() {
- // freopen("input.txt", "r", stdin);
- ios_base::sync_with_stdio(0);
- cin.tie(0);
- // srand(time(0));
- int o, k;
- cin >> o >> k;
- string s, p;
- cin >> s >> p;
- int n = s.size(), m = p.size();
- vector<int> a, b;
- a.resize(n);
- b.resize(m);
- for (int i = 0; i < n; i++) {
- a[n - i - 1] = eval(s[i]);
- cnto[i + 1] = cnto[i] + (s[i] == 'o');
- cntk[i + 1] = cntk[i] + (s[i] == 'k');
- }
- int cnt0 = 0;
- for (int i = 0; i < m; i++) {
- b[i] = eval(p[i]);
- cnt0 += b[i] == 0;
- }
- // for (int i = 0; i < n; i++) {
- // cout << a[i] << ' ';
- // }
- // cout << endl;
- // for (int i = 0; i < m; i++) {
- // cout << b[i] << ' ';
- // }
- auto c = fft::multiply(a, b);
- // cout << c[7];
- // return 0;
- for (int i = 0; i + m <= n; i++) {
- if (n - i - 1 >= c.size()) {
- throw;
- }
- // cout << i + n - 1 << ' ' << c[n - i - 1] << endl;
- int x = c[n - i - 1];
- int q = cnt0;
- val[i] = m - cnt0 - (x + m + cnt0) / 2;
- // cout << x << ' ' << q << ' ' << m << endl;
- // int y = x - m + q;
- // assert(y % 2 == 0);
- // val[i] = m - q - y / 2;
- // cout << val[i] << endl;
- }
- for (int i = 0; i < n; i++) {
- if (i + m <= n) {
- long long kek = 1ll * o * (cnto[i + m] - cnto[i]) + 1ll * k * (cntk[i + m] - cntk[i]);
- for (int j = 0; j < val[i] && kek > 0; j++) {
- kek /= 2;
- }
- upmax(dp[i + m], dp[i] + kek);
- }
- upmax(dp[i + 1], dp[i]);
- }
- cout << dp[n];
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement