Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #pragma GCC optimize("O3")
- #include <iostream>
- #include <vector>
- #include <algorithm>
- #include <cassert>
- #include <array>
using namespace std;
using ll = long long;

// Upper bound on the number of vertices; every per-vertex table below is sized by it.
constexpr int N = 200001;
// Sentinel for "impossible" DP cells: exactly 10^18 + 228.
constexpr ll INF = 1000000000000000000LL + 228;

// Adjacency lists of (neighbour, edge weight); prepare() later strips the parent edge.
vector<pair<int, int>> g[N];
// Vertex weights read from the input.
int w[N];
// Subtree sizes computed by prepare().
int sz[N];
// Weight of the edge to the parent and the parent itself, both filled by prepare().
int par_w[N];
int par_v[N];
- void prepare(int v, int p) {
- if (p != -1) {
- auto it = g[v].begin();
- while (it->first != p) {
- it++;
- }
- par_v[v] = p;
- par_w[v] = it->second;
- g[v].erase(it);
- }
- sz[v] = 1;
- for (int i = 0; i < g[v].size(); i++) {
- int to = g[v][i].first;
- prepare(to, v);
- sz[v] += sz[to];
- if (sz[to] > sz[g[v][0].first]) {
- swap(g[v][0], g[v][i]);
- }
- }
- }
// Back-pointer used to restore the optimal solution after the DP.
// The four fields are packed into a single long long to reduce memory usage.
// Bit layout, low to high: lhs_k (20 bits) | lhs_m (10) | rhs_k (20) | rhs_m (10).
// The producer decides what each slot means. The merges store the count taken from
// each operand in lhs_k/rhs_k and the operand state masks in lhs_m/rhs_m. Other
// producers reuse the slots, e.g. as 0/1 flags.
struct From {
    int lhs_k;
    int lhs_m;
    int rhs_k;
    int rhs_m;
};

// Field widths of the packed representation; code() and decode() both derive from them.
const int K_BITS = 20;  // width of lhs_k / rhs_k
const int M_BITS = 10;  // width of lhs_m / rhs_m
const long long HI_MASK = (1LL << K_BITS) - 1;  // selects one k field
const long long LO_MASK = (1LL << M_BITS) - 1;  // selects one m field

// Packs the four fields into one integer.
// Every field must be non-negative and fit its width. Otherwise neighbouring fields
// would be silently corrupted, so the ranges are checked with assert.
long long code(long long lhs_k, long long lhs_m, long long rhs_k, long long rhs_m) {
    assert(0 <= lhs_k && lhs_k <= HI_MASK);
    assert(0 <= lhs_m && lhs_m <= LO_MASK);
    assert(0 <= rhs_k && rhs_k <= HI_MASK);
    assert(0 <= rhs_m && rhs_m <= LO_MASK);
    return lhs_k
        | (lhs_m << K_BITS)
        | (rhs_k << (K_BITS + M_BITS))
        | (rhs_m << (2 * K_BITS + M_BITS));
}

// Inverse of code(): unpacks the four fields, lowest field first.
From decode(long long x) {
    const int lhs_k = (int)(x & HI_MASK);
    x >>= K_BITS;
    const int lhs_m = (int)(x & LO_MASK);
    x >>= M_BITS;
    const int rhs_k = (int)(x & HI_MASK);
    x >>= K_BITS;
    const int rhs_m = (int)(x & LO_MASK);
    return { lhs_k, lhs_m, rhs_k, rhs_m };
}
- struct universal_dp_state {
- vector<ll> dp[4]; // either [used_v][open_edge] or [left_used][right_used] or [edge_to_parent]
- vector<ll> from[4];
- universal_dp_state* lhs;
- universal_dp_state* rhs;
- };
- universal_dp_state* dps[N];
- universal_dp_state* heavy_root[N];
- universal_dp_state* create_empty_ptr() {
- // TODO: still may use static object pool.. objects are removed though
- universal_dp_state* res = new universal_dp_state();
- res->dp[0].push_back(0);
- res->dp[1].push_back(-INF);
- res->dp[2].push_back(-INF);
- res->dp[3].push_back(-INF);
- res->lhs = res->rhs = nullptr;
- return res;
- }
// Reports whether bit i of x is set.
bool bit(int x, int i) {
    const int shifted = x >> i;
    return (shifted & 1) != 0;
}
// Packs two flags into a 2-bit mask: bit 0 comes from lhs, bit 1 from rhs.
int get_mask(bool lhs, bool rhs) {
    return (lhs ? 1 : 0) | (rhs ? 2 : 0);
}
- void trim(vector<ll>& a) {
- while (a.size() > 1 && a.back() == -INF) {
- a.pop_back();
- }
- }
- void trim(vector<pair<ll, int>>& a) {
- while (a.size() > 1 && a.back().first == -INF) {
- a.pop_back();
- }
- }
- void trim(universal_dp_state* a) {
- for (int i = 0; i < 4; i++) {
- trim(a->dp[i]);
- int sz = a->dp[i].size();
- if (sz < a->from[i].size()) {
- a->from[i].resize(sz);
- }
- }
- }
- void add_delta(vector<pair<ll, int>>& dst, const vector<ll>& src, int i, int lhs_i) {
- dst.emplace_back(dst.back().first + src[i] - src[i - 1], lhs_i);
- }
- vector<pair<ll, int>> minkowski_sum(vector<ll>& lhs, vector<ll>& rhs) {
- trim(lhs);
- int lsz = lhs.size();
- if (lsz == 1 && lhs[0] == -INF) {
- return {};
- }
- trim(rhs);
- int rsz = rhs.size();
- if (rsz == 1 && rhs[0] == -INF) {
- return {};
- }
- vector<pair<ll, int>> res;
- res.reserve(lsz + rsz - 1);
- int i = 0;
- int j = 0;
- while (i < lsz && lhs[i] == -INF) {
- i++;
- res.emplace_back(-INF, 0);
- }
- while (j < rsz && rhs[j] == -INF) {
- j++;
- res.emplace_back(-INF, 0);
- }
- res.emplace_back(lhs[i] + rhs[j], i);
- i++;
- j++;
- int sum_sz = lsz + rsz;
- while (i + j < sum_sz) {
- if (j == rsz) {
- add_delta(res, lhs, i, i);
- i++;
- }
- else if (i == lsz) {
- add_delta(res, rhs, j, i - 1);
- j++;
- }
- else if (lhs[i] - lhs[i - 1] > rhs[j] - rhs[j - 1]) {
- add_delta(res, lhs, i, i);
- i++;
- }
- else {
- add_delta(res, rhs, j, i - 1);
- j++;
- }
- }
- trim(res);
- return res;
- }
- const int MEM_DEPTH = 1;
- universal_dp_state* merge_kids(universal_dp_state* lhs, universal_dp_state* rhs, int dep) {
- // used_lhs | used_rhs
- // open_lhs ^ open_rhs
- universal_dp_state* res = create_empty_ptr();
- if (dep <= MEM_DEPTH) {
- res->lhs = lhs;
- res->rhs = rhs;
- }
- for (int ml = 0; ml < 4; ml++) {
- for (int mr = 0; mr < 4; mr++) {
- bool open_lhs = bit(ml, 1);
- bool open_rhs = bit(mr, 1);
- int msk = get_mask(bit(ml, 0) || bit(mr, 0), open_lhs ^ open_rhs);
- vector<ll>& cur_dp = res->dp[msk];
- vector<ll>& cur_from = res->from[msk];
- vector<pair<ll, int>> candy = minkowski_sum(lhs->dp[ml], rhs->dp[mr]);
- int sz = candy.size();
- if (cur_dp.size() < sz) {
- cur_dp.resize(sz, -INF);
- cur_from.resize(sz);
- }
- int delta = (open_lhs && open_rhs);
- for (int i = 0; i + delta < sz; i++) {
- ll val = candy[i + delta].first;
- int lhs_k = candy[i + delta].second;
- if (val > cur_dp[i]) {
- cur_dp[i] = val;
- cur_from[i] = code(lhs_k, ml, i + delta - lhs_k, mr);
- }
- }
- }
- }
- trim(res);
- if (dep > MEM_DEPTH) {
- delete(lhs);
- delete(rhs);
- }
- return res;
- }
- universal_dp_state* init_light_leaf(int v, int w) {
- const auto& dp = dps[v]->dp;
- universal_dp_state* ret = create_empty_ptr();
- // as is max(free, used) -> [0][0]
- int mask_as_is = get_mask(false, false);
- int sz0 = dp[0].size();
- int sz1 = dp[1].size();
- int sz = max(sz0, sz1);
- ret->dp[mask_as_is].resize(sz, -INF);
- ret->from[mask_as_is].resize(sz);
- for (int i = 0; i < sz; i++) {
- ll val0 = (i < sz0 ? dp[0][i] : -INF);
- ll val1 = (i < sz1 ? dp[1][i] : -INF);
- if (val1 > val0) {
- ret->dp[mask_as_is][i] = val1;
- ret->from[mask_as_is][i] = code(i, 1, 0, 0);
- }
- else {
- ret->dp[mask_as_is][i] = val0;
- ret->from[mask_as_is][i] = code(i, 0, 0, 0);
- }
- }
- // +1 edge free -> [1][1]
- int mask_take_edge = get_mask(true, true);
- sz = dp[0].size() + 1;
- ret->dp[mask_take_edge].resize(sz, -INF);
- ret->from[mask_take_edge].resize(sz);
- for (int i = 1; i < sz; i++) {
- ll val = dp[0][i - 1];
- if (val != -INF) {
- ret->dp[mask_take_edge][i] = val - w;
- ret->from[mask_take_edge][i] = code(i - 1, 0, 0, 0);
- }
- }
- trim(ret);
- return ret;
- }
- universal_dp_state* dq_kids(int l, int r, const vector<pair<int, int>>& kids, int dep) {
- if (l == r) {
- // init state of one child
- int v = kids[l].first;
- int w = kids[l].second;
- return init_light_leaf(v, w);
- }
- int m = (l + r) >> 1;
- universal_dp_state* left = dq_kids(l, m, kids, dep + 1);
- universal_dp_state* right = dq_kids(m + 1, r, kids, dep + 1);
- return merge_kids(left, right, dep);
- }
- universal_dp_state* merge_heavy(universal_dp_state* lhs, universal_dp_state* rhs, bool single_lhs, bool single_rhs, int dep) {
- universal_dp_state* res = create_empty_ptr();
- if (dep <= MEM_DEPTH) {
- res->lhs = lhs;
- res->rhs = rhs;
- }
- for (int ml = 0; ml < 4; ml++) {
- for (int mr = 0; mr < 4; mr++) {
- if (bit(ml, 1) && bit(mr, 0)) {
- continue;
- }
- int msk = get_mask(bit(ml, 0), bit(mr, 1));
- auto& cur_dp = res->dp[msk];
- auto& cur_from = res->from[msk];
- vector<pair<ll, int>> candy = minkowski_sum(lhs->dp[ml], rhs->dp[mr]);
- int sz = candy.size();
- if (cur_dp.size() < sz) {
- cur_dp.resize(sz, -INF);
- cur_from.resize(sz);
- }
- for (int i = 0; i < sz; i++) {
- ll val = candy[i].first;
- int lhs_k = candy[i].second;
- if (val > cur_dp[i]) {
- cur_dp[i] = val;
- cur_from[i] = code(lhs_k, ml, i - lhs_k, mr);
- }
- }
- }
- }
- trim(res);
- if (dep > MEM_DEPTH) {
- if (!single_lhs) {
- delete(lhs);
- }
- if (!single_rhs) {
- delete(rhs);
- }
- }
- return res;
- }
- universal_dp_state* dq_heavy(int l, int r, const vector<int>& path, int dep) {
- if (l == r) {
- return dps[path[l]]->rhs; // heavy leaf
- }
- int m = (l + r) >> 1;
- universal_dp_state* left = dq_heavy(l, m, path, dep + 1);
- universal_dp_state* right = dq_heavy(m + 1, r, path, dep + 1);
- return merge_heavy(left, right, (l == m), (m + 1 == r), dep);
- }
- void recalc_dp_heavy(int v) {
- vector<int> heavy_path = { v };
- int cur = v;
- while (!g[cur].empty()) {
- cur = g[cur][0].first;
- heavy_path.push_back(cur);
- }
- // fix dp[v] for v in path to be used in dq_heavy
- for (int u : heavy_path) {
- // сейчас там лежит [used][open]
- // дополнительно можем взять ребро вниз в тяжелого сына и ребро вверх
- // это влияет на used и open (последний должен стать 0). далее если used true, добавить вес вершины
- universal_dp_state* new_state = create_empty_ptr();
- for (int take_par = 0; take_par < 2; take_par++) {
- if (u == 0 && take_par) {
- continue;
- }
- for (int take_heavy = 0; take_heavy < 2; take_heavy++) {
- if (g[u].empty() && take_heavy) {
- continue;
- }
- for (int state = 0; state < 4; state++) {
- int used = bit(state, 0);
- int open = bit(state, 1);
- ll delta_w = 0;
- int delta_cnt = 0;
- if (take_par) {
- open ^= 1;
- delta_cnt++;
- delta_w -= par_w[u];
- used = 1;
- }
- if (take_heavy) {
- open ^= 1;
- delta_cnt++;
- delta_w -= g[u][0].second;
- used = 1;
- }
- if (open) {
- continue;
- }
- if (used) {
- delta_w += w[u];
- }
- delta_cnt >>= 1;
- auto& old_dp = dps[u]->lhs->dp[state]; // light root
- int sz = old_dp.size() + delta_cnt;
- int new_msk = get_mask(take_par, take_heavy);
- auto& new_dp = new_state->dp[new_msk];
- auto& new_from = new_state->from[new_msk];
- if (new_dp.size() < sz) {
- new_dp.resize(sz, -INF);
- new_from.resize(sz);
- }
- for (int i = sz - 1; i >= delta_cnt; i--) {
- ll val = old_dp[i - delta_cnt];
- if (val != -INF) {
- val += delta_w;
- if (val > new_dp[i]) {
- new_dp[i] = val;
- new_from[i] = code(i - delta_cnt, state, take_par, take_heavy);
- }
- }
- }
- }
- }
- }
- trim(new_state);
- dps[u]->rhs = new_state;
- }
- universal_dp_state* res = dq_heavy(0, (int)heavy_path.size() - 1, heavy_path, 0);
- heavy_root[v] = res;
- // fill dp[v]
- for (int msk = 0; msk < 4; msk++) {
- const auto& old_dp = res->dp[msk];
- int sz = old_dp.size();
- int new_msk = get_mask(bit(msk, 0), 0);
- auto& cur_dp = dps[v]->dp[new_msk];
- auto& cur_from = dps[v]->from[new_msk];
- if (cur_dp.size() < sz) {
- cur_dp.resize(sz, -INF);
- cur_from.resize(sz);
- }
- for (int i = 0; i < sz; i++) {
- if (old_dp[i] > cur_dp[i]) {
- cur_dp[i] = old_dp[i];
- cur_from[i] = code(i, msk, 0, 0);
- }
- }
- }
- trim(dps[v]);
- }
- void solve(int v) {
- dps[v] = create_empty_ptr();
- for (auto e : g[v]) {
- solve(e.first);
- }
- for (int i = 1; i < g[v].size(); i++) {
- // запускает разделяйку на тяжёлых путях, начинающихся в детях; пересчитывает значение dp
- recalc_dp_heavy(g[v][i].first);
- }
- if (g[v].size() > 1) {
- // мержит **уже правильно посчитанные** дпшки лёгких детей.
- // В текущей вершине оставляет значения as is, исправить его должен будет пересчёт на соответствующем тяжёлом пути
- dps[v]->lhs = dq_kids(1, (int)g[v].size() - 1, g[v], 0);
- }
- else {
- dps[v]->lhs = create_empty_ptr();
- }
- }
- vector<int> ans[N];
- void restore_heavy_root(int, int, int);
- void dq_light_restore(int l, int r, const vector<pair<int, int>>& g, int k, int msk, universal_dp_state* node) {
- if (l == r) {
- int v = g[l].first;
- int w = g[l].second;
- universal_dp_state* cur_node = (node != nullptr ? node : init_light_leaf(v, w));
- From from = decode(cur_node->from[msk][k]);
- delete(cur_node);
- if (from.lhs_k != k) {
- ans[par_v[v]].push_back(v);
- }
- if (from.lhs_k > 0) {
- restore_heavy_root(v, from.lhs_k, from.lhs_m);
- }
- return;
- }
- universal_dp_state* cur_root = (node != nullptr ? node : dq_kids(l, r, g, 0));
- universal_dp_state* lhs = cur_root->lhs;
- universal_dp_state* rhs = cur_root->rhs;
- From from = decode(cur_root->from[msk][k]);
- delete(cur_root);
- int m = (l + r) >> 1;
- if (from.lhs_k > 0) {
- dq_light_restore(l, m, g, from.lhs_k, from.lhs_m, lhs);
- }
- if (from.rhs_k > 0) {
- dq_light_restore(m + 1, r, g, from.rhs_k, from.rhs_m, rhs);
- }
- }
- void dq_heavy_restore(int l, int r, const vector<int>& path, int k, int msk, universal_dp_state* node) {
- if (l == r) {
- int v = path[l];
- universal_dp_state* src = dps[v]->rhs;
- From from = decode(src->from[msk][k]);
- if (from.rhs_k == 1) {
- ans[v].push_back(par_v[v]);
- }
- if (from.rhs_m == 1) {
- ans[v].push_back(g[v][0].first);
- }
- if (from.lhs_k > 0) {
- dq_light_restore(1, (int)g[v].size() - 1, g[v], from.lhs_k, from.lhs_m, dps[v]->lhs);
- }
- return;
- }
- universal_dp_state* cur_root = (node != nullptr ? node : dq_heavy(l, r, path, 0));
- universal_dp_state* lhs = cur_root->lhs;
- universal_dp_state* rhs = cur_root->rhs;
- From from = decode(cur_root->from[msk][k]);
- delete(cur_root);
- int m = (l + r) >> 1;
- if (from.lhs_k > 0) {
- dq_heavy_restore(l, m, path, from.lhs_k, from.lhs_m, lhs);
- }
- if (from.rhs_k > 0) {
- dq_heavy_restore(m + 1, r, path, from.rhs_k, from.rhs_m, rhs);
- }
- }
- void restore_heavy_root(int v, int k, int msk) {
- From from = decode(dps[v]->from[msk][k]);
- vector<int> heavy_path = { v };
- int cur = v;
- while (!g[cur].empty()) {
- cur = g[cur][0].first;
- heavy_path.push_back(cur);
- }
- dq_heavy_restore(0, (int)heavy_path.size() - 1, heavy_path, from.lhs_k, from.lhs_m, heavy_root[v]);
- }
- void work_work(int n, int k, int t) {
- for (int i = 0; i < n; i++) {
- cin >> w[i];
- }
- for (int i = 0; i < n - 1; i++) {
- int u, v, s;
- cin >> u >> v >> s;
- //u = i + 1, v = i + 2, s = 0;
- --u; --v;
- g[u].emplace_back(v, s);
- g[v].emplace_back(u, s);
- }
- prepare(0, -1);
- solve(0);
- recalc_dp_heavy(0);
- cout << dps[0]->dp[0][k] << '\n';
- if (t == 1) {
- restore_heavy_root(0, k, 0);
- for (int u = 0; u < n; u++) {
- for (int i = 0; i < ans[u].size(); i += 2) {
- cout << u + 1 << ' ' << ans[u][i] + 1 << ' ' << ans[u][i + 1] + 1 << '\n';
- }
- }
- }
- }
- int main() {
- ios::sync_with_stdio(false);
- cin.tie(nullptr);
- int n, k, t;
- cin >> n >> k >> t;
- work_work(n, k, t);
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement