Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <bits/stdc++.h>
- #include <ext/pb_ds/assoc_container.hpp>
- #include <ext/pb_ds/tree_policy.hpp>
- #include <ext/rope>
// ---------------------------------------------------------------------------
// Type shorthands (plain textual macros, expanded where they are used).
// ---------------------------------------------------------------------------
#define ll long long
#define ld long double
// Unsigned 128-bit integer, despite the "ll" in its name.
#define ll128 __uint128_t
#define pll pair<ll, ll>
#define vll vector<ll>
#define vvll vector<vll>
#define vld vector<ld>

// ---------------------------------------------------------------------------
// Loop and container helpers.
// rep(i, a, b) counts i = a, a+1, ..., b-1; per(i, a, b) counts i = a-1 down to b.
// The arguments are substituted unparenthesised, exactly as written.
// ---------------------------------------------------------------------------
#define rep(i, a, b) for(ll i = a; i < b; i++)
#define per(i, a, b) for(ll i = a - 1; i >= b; --i)
#define all(v) (v).begin(), (v).end()
#define rall(v) (v).rbegin(), (v).rend()
#define sorta(v) sort(all(v))
#define sortd(v) sort(rall(v))
#define pb push_back
#define pf push_front

// ---------------------------------------------------------------------------
// I/O and debugging.
// `endl` becomes a plain newline string, so it no longer flushes the stream.
// `log(val)` prints "<expression text>: <value>" to cout; note it replaces any
// later call spelled log(x), including the <cmath> logarithm.
// `ios` expands to a complete statement sequence (fast, untied iostreams).
// ---------------------------------------------------------------------------
#define endl "\n"
#define debug if (1)
#define log(val) debug {cout << "\n" << #val << ": " << val << "\n";}
#define ios ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);

// ---------------------------------------------------------------------------
// Constants.
// ---------------------------------------------------------------------------
#define mod (ll)(1e9 + 7)
- using namespace std;
- using namespace __gnu_cxx;
- using namespace __gnu_pbds;
// Writes every element of `a` (a vll), each followed by a single space: the
// output therefore ends with a trailing space and no newline.
// Taking the vector by const reference (instead of vll &) also allows const
// vectors and temporaries to be printed.
std::ostream &operator<<(std::ostream &out, const std::vector<long long> &a) {
    for (const long long x : a) {
        out << x << ' ';
    }
    return out;
}

// Reads exactly a.size() whitespace-separated integers into `a`, in order.
// The caller must size the vector beforehand; nothing is appended.
std::istream &operator>>(std::istream &in, std::vector<long long> &a) {
    for (long long &x : a) {
        in >> x;
    }
    return in;
}
// Capacity for vertex ids (0-indexed); presumably the input guarantees n < N.
const long long N = 300000 + 10;
// Tree adjacency lists, one per vertex.
std::vector<long long> g[N];
// Entry / exit timestamps assigned by dfs(); used for the ancestor test.
long long tin[N], tout[N];
// Highest binary-lifting level index.
const long long logn = 20;
// up[v][i] is the 2^i-th ancestor of v, where the DFS start vertex is its own
// parent. dfs() fills columns 0..logn inclusive and lca() reads the same range,
// so the table needs logn + 1 columns. With only logn columns, up[v][logn]
// overlapped up[v + 1][0] and silently corrupted parent links.
long long up[N][logn + 1];
// Next timestamp to hand out; dfs() advances it on every entry and exit.
long long timer = 1;
// Depth of each vertex (the DFS start vertex has depth 1 by default).
long long h[N];
- void dfs(ll v = 0, ll p = 0, ll curh = 1) {
- tin[v] = timer++;
- h[v] = curh;
- up[v][0] = p;
- rep(i, 1, logn + 1) {
- up[v][i] = up[up[v][i - 1]][i - 1];
- }
- for(auto i : g[v]) {
- if(i != p) {
- dfs(i, v, curh + 1);
- }
- }
- tout[v] = timer++;
- }
- bool parent(ll a, ll b) {
- return (tin[a] <= tin[b] && tout[a] >= tout[b]);
- }
- ll lca(ll a, ll b) {
- if(parent(a, b)) return a;
- if(parent(b, a)) return b;
- for(int i = logn; i >= 0; i--) {
- if(!parent(up[a][i], b)) a = up[a][i];
- }
- return up[a][0];
- }
- int main() {
- // freopen("input.txt","r", stdin);
- // freopen("output.txt","w", stdout);
- ll n, k;
- cin >> n >> k;
- vll a(n + 1);
- rep(i, 1, n + 1) {
- cin >> a[i];
- }
- for (auto &i : a) i--;
- rep(i, 0, n - 1) {
- ll v1, v2;
- cin >> v1 >> v2;
- g[v1 - 1].pb(v2 - 1);
- g[v2 - 1].pb(v1 - 1);
- }
- dfs(0);
- vvll dp(n + 1, vll(k + 1, 1e18));
- if (n == 1) {
- return cout << h[a[1]], 0;
- }
- dp[0][0] = 0;
- rep (i, 0, n) {
- // cout << i << endl;
- rep(j, 0, k) {
- dp[i + 1][j] = min(dp[i + 1][j], dp[i][j]);
- dp[i + 1][j + 1] = min(dp[i + 1][j + 1], dp[i][j] + h[a[i + 1]]);
- if (i + 2 <= n) {
- dp[i + 2][j + 1] = min(dp[i + 2][j + 1], dp[i][j] + h[lca(a[i + 1], a[i + 2])]);
- }
- }
- }
- /*
- rep(i, 0, n + 1) {
- rep(j, 0, k + 1) {
- cout << dp[i][j] << " ";
- }
- cout << endl;
- }
- */
- cout << dp[n][k] << endl;
- return 0;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement