Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <iostream>
- #include <vector>
- #include <algorithm>
- #include <cmath>
- #include <list>
- #include <stack>
- #include <deque>
- #include <bitset>
- #include <set>
using namespace std;
typedef long long ll;

// Modulus under which every count is computed and the answer is printed.
const ll MOD1 = 1e9 + 7;

// Product, modulo MOD1, of the counts of all components processed so far.
ll ans = 1;
// n     - number of vertices (1-based), read from input.
// k     - second input value; used as the base of the counting formulas
//         (presumably the number of available colours - confirm against the task).
// kmin  - vertices of the current component that are not on its detected cycle.
// cycle - length of the cycle detected in the current component, 0 if none.
ll n, k, kmin, cycle;
// True once a cycle has been detected in the current component.
bool ch;
// Undirected adjacency lists: vertex i is linked with the value read for it.
vector<int> g[1000001];
// pred[v] is the vertex from which v was most recently pushed by the traversal.
int pred[1000001];

// Largest vertex index that g, pred and the visited bitset can hold.
static const ll MAX_VERTEX = 1000000;

// Returns base^exp modulo MOD1 by square-and-multiply.
// Expects 0 <= base < MOD1; a non-positive exp yields 1 (the empty product).
static ll pow_mod(ll base, ll exp) {
    ll result = 1;
    while (exp > 0) {
        if (exp & 1) {
            result = result * base % MOD1;
        }
        base = base * base % MOD1;
        exp >>= 1;
    }
    return result;
}

// Reads "n k" followed by n values x_1..x_n, links i with x_i whenever they
// differ, and prints the product over all connected components of:
//   - k * (k-1)^(size-1)                       if no cycle was detected,
//   - ((k-1)^c + (-1)^c * (k-1)) * (k-1)^rest  if a cycle of length c was found
//     (rest = component size - c),
// everything modulo MOD1. These match the counts of proper k-colourings of a
// tree and of a cycle with trees attached - presumably the task being solved.
// Returns 1 without printing an answer when the input is malformed.
int main() {
    // Up to 10^6 numbers are read; untie the C++ streams from C stdio.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    if (!(cin >> n >> k) || n < 0 || n > MAX_VERTEX) {
        cerr << "invalid input: expected n in [0, " << MAX_VERTEX << "] and k\n";
        return 1;
    }
    for (int i = 1; i <= n; i++) {
        int x;
        // x indexes g/pred, so it must be a valid vertex.
        if (!(cin >> x) || x < 1 || x > n) {
            cerr << "invalid input: value for vertex " << i << " must be in [1, n]\n";
            return 1;
        }
        if (x != i) {  // self-links contribute no edge
            g[i].push_back(x);
            g[x].push_back(i);
        }
    }

    // Reduce k and k-1 into [0, MOD1) once, so the sums and differences below
    // can never go negative or overflow regardless of how large k is.
    const ll k_mod = (k % MOD1 + MOD1) % MOD1;
    const ll k_minus_1 = (k_mod - 1 + MOD1) % MOD1;

    bitset<1000001> vis;
    stack<int> st;
    for (int i = 1; i <= n; i++) {
        if (vis[i]) {
            continue;
        }
        ch = false;
        cycle = 0;
        kmin = 0;

        // Iterative depth-first traversal; a vertex is marked when popped, so
        // it may sit on the stack several times and only the first pop counts.
        st.push(i);
        while (!st.empty()) {
            const int cur = st.top();
            st.pop();
            if (vis[cur]) {
                continue;
            }
            kmin++;
            vis[cur] = true;
            for (int y : g[cur]) {
                // Skip the edge back to the vertex we came from. Two vertices
                // that point at each other are linked twice; both copies are
                // skipped here, so such a pair is handled as "no cycle".
                if (y == pred[cur]) {
                    continue;
                }
                if (!vis[y]) {
                    pred[y] = cur;
                    st.push(y);
                } else if (!ch) {
                    // Every vertex adds at most one edge, so a component has
                    // at most one cycle: the first visited non-parent
                    // neighbour closes it. Walk pred links from cur up to y,
                    // counting the vertices on the cycle.
                    ch = true;
                    ll length = 1;
                    for (int v = cur; v != y; v = pred[v]) {
                        length++;
                    }
                    cycle = length;
                }
            }
        }
        kmin -= cycle;  // keep only the vertices that are off the cycle

        ll temp;
        if (cycle) {
            temp = pow_mod(k_minus_1, cycle);
            // + (k-1) for an even cycle, - (k-1) for an odd one.
            temp += (cycle % 2 == 0) ? k_minus_1 : MOD1 - k_minus_1;
            temp %= MOD1;
            temp = temp * pow_mod(k_minus_1, kmin) % MOD1;
        } else {
            temp = k_mod * pow_mod(k_minus_1, kmin - 1) % MOD1;
        }
        ans = ans * temp % MOD1;
    }

    cout << ans;
    //system("pause>nul");
    return 0;
}
/*
Sample input (first line: n k; second line: the n values read for vertices 1..n):
12 3
7 7 10 6 2 3 5 1 9 4 11 11
*/
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement