Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <bits/stdc++.h>
- using namespace std;
// Shared state for the solution.
using ll = long long;

// Bound for every dimension of f. Solve() indexes f with 0..|x| and 0..|y|,
// so both input lengths must stay below N.
constexpr int N = 1005;
// Every count stored in f (and the printed answer) is reduced modulo this.
constexpr ll mod = 998244353;

std::string x; // first input string; after Read(), 1-indexed (x[0] is padding)
std::string y; // second input string; same 1-indexed layout
int n;         // |y|, the length before padding
int m;         // |x|, the length before padding
ll f[N][N][2]; // DP table filled by Solve(); zero-initialised as a global
- void Read()
- {
- cin >> x >> y;
- m = x.size();
- n = y.size();
- x = " " + x;
- y = " " + y;
- }
- void Solve()
- {
- ll ans(0);
- for (int i = 0; i <= m; ++i)
- for (int j = 0; j <= n; ++j)
- if ((i == 0) + (j == 0) < 2)
- {
- // t = 0
- if (i != 0)
- {
- if (j == 0)
- (++f[i][j][0]) %= mod;
- else if (x[i] != y[j])
- (f[i][j][0] += f[0][j][1]) %= mod;
- if (i < m && x[i] != x[i + 1])
- (f[i + 1][j][0] += f[i][j][0]) %= mod;
- if (j < n && j != 0 && x[i] != y[j + 1])
- (f[i][j + 1][1] += f[i][j][0]) %= mod;
- }
- // t = 1
- if (j != 0)
- {
- if (i == 0)
- (++f[i][j][1]) %= mod;
- else if (x[i] != y[j])
- (f[i][j][1] += f[i][0][0]) %= mod;
- if (i < m && i != 0 && y[j] != x[i + 1])
- (f[i + 1][j][0] += f[i][j][1]) %= mod;
- if (j < n && y[j] != y[j + 1])
- (f[i][j + 1][1] += f[i][j][1]) %= mod;
- }
- //cout << i << " " << j << ": " << f[i][j][0] << " " << f[i][j][1] << '\n';
- if (i != 0 && j != 0)
- (ans += f[i][j][0] + f[i][j][1]) %= mod;
- }
- cout << ans;
- }
- int32_t main()
- {
- ios::sync_with_stdio(0);
- cin.tie(0);
- cout.tie(0);
- Read();
- Solve();
- }
Advertisement
Add Comment
Please, Sign In to add comment