Advertisement
Guest User

Упячка

a guest
Jan 21st, 2024
110
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 13.96 KB | None | 0 0
  1. #pragma GCC optimize("O3")
  2.  
  3. #include <iostream>
  4. #include <vector>
  5. #include <algorithm>
  6. #include <cassert>
  7. #include <array>
  8.  
  9. using namespace std;
  10. using ll = long long;
  11.  
  12. const int N = 200001;
  13. const ll INF = (ll)1e18 + 228;
  14.  
  15. vector<pair<int, int>> g[N];
  16. int w[N];
  17. int sz[N];
  18. int par_w[N];
  19. int par_v[N];
  20.  
  21. void prepare(int v, int p) {
  22.     if (p != -1) {
  23.         auto it = g[v].begin();
  24.         while (it->first != p) {
  25.             it++;
  26.         }
  27.         par_v[v] = p;
  28.         par_w[v] = it->second;
  29.         g[v].erase(it);
  30.     }
  31.     sz[v] = 1;
  32.     for (int i = 0; i < g[v].size(); i++) {
  33.         int to = g[v][i].first;
  34.         prepare(to, v);
  35.         sz[v] += sz[to];
  36.         if (sz[to] > sz[g[v][0].first]) {
  37.             swap(g[v][0], g[v][i]);
  38.         }
  39.     }
  40. }
  41.  
  42. // fit in single long long to reduce mem usage
  43. // 20 10 20 10
  44. struct From {
  45.     int lhs_k;
  46.     int lhs_m;
  47.     int rhs_k;
  48.     int rhs_m;
  49. };
  50.  
  51. ll code(ll lhs_k, ll lhs_m, ll rhs_k, ll rhs_m) {
  52.     return (lhs_k | (lhs_m << 20) | (rhs_k << 30) | (rhs_m << 50));
  53. }
  54.  
  55. const ll HI_MASK = (1LL << 20) - 1;
  56. const ll LO_MASK = (1LL << 10) - 1;
  57.  
  58. From decode(ll x) {
  59.     int lhs_k = x & HI_MASK;
  60.     x >>= 20;
  61.     int lhs_m = x & LO_MASK;
  62.     x >>= 10;
  63.     int rhs_k = x & HI_MASK;
  64.     x >>= 20;
  65.     int rhs_m = x & LO_MASK;
  66.     return { lhs_k, lhs_m, rhs_k, rhs_m };
  67. }
  68.  
  69. struct universal_dp_state {
  70.     vector<ll> dp[4]; // either [used_v][open_edge] or [left_used][right_used] or [edge_to_parent]
  71.     vector<ll> from[4];
  72.     universal_dp_state* lhs;
  73.     universal_dp_state* rhs;
  74. };
  75.  
  76. universal_dp_state* dps[N];
  77. universal_dp_state* heavy_root[N];
  78.  
  79. universal_dp_state* create_empty_ptr() {
  80.     // TODO: still may use static object pool.. objects are removed though
  81.     universal_dp_state* res = new universal_dp_state();
  82.     res->dp[0].push_back(0);
  83.     res->dp[1].push_back(-INF);
  84.     res->dp[2].push_back(-INF);
  85.     res->dp[3].push_back(-INF);
  86.     res->lhs = res->rhs = nullptr;
  87.     return res;
  88. }
  89.  
  90. bool bit(int x, int i) {
  91.     return (x >> i & 1);
  92. }
  93.  
  94. int get_mask(bool lhs, bool rhs) {
  95.     int res = 0;
  96.     if (lhs) {
  97.         res += 1;
  98.     }
  99.     if (rhs) {
  100.         res += 2;
  101.     }
  102.     return res;
  103. }
  104.  
  105. void trim(vector<ll>& a) {
  106.     while (a.size() > 1 && a.back() == -INF) {
  107.         a.pop_back();
  108.     }
  109. }
  110.  
  111. void trim(vector<pair<ll, int>>& a) {
  112.     while (a.size() > 1 && a.back().first == -INF) {
  113.         a.pop_back();
  114.     }
  115. }
  116.  
  117. void trim(universal_dp_state* a) {
  118.     for (int i = 0; i < 4; i++) {
  119.         trim(a->dp[i]);
  120.         int sz = a->dp[i].size();
  121.         if (sz < a->from[i].size()) {
  122.             a->from[i].resize(sz);
  123.         }
  124.     }
  125. }
  126.  
  127. void add_delta(vector<pair<ll, int>>& dst, const vector<ll>& src, int i, int lhs_i) {
  128.     dst.emplace_back(dst.back().first + src[i] - src[i - 1], lhs_i);
  129. }
  130.  
  131. vector<pair<ll, int>> minkowski_sum(vector<ll>& lhs, vector<ll>& rhs) {
  132.  
  133.     trim(lhs);
  134.     int lsz = lhs.size();
  135.     if (lsz == 1 && lhs[0] == -INF) {
  136.         return {};
  137.     }
  138.  
  139.     trim(rhs);
  140.     int rsz = rhs.size();
  141.     if (rsz == 1 && rhs[0] == -INF) {
  142.         return {};
  143.     }
  144.  
  145.     vector<pair<ll, int>> res;
  146.     res.reserve(lsz + rsz - 1);
  147.  
  148.     int i = 0;
  149.     int j = 0;
  150.  
  151.     while (i < lsz && lhs[i] == -INF) {
  152.         i++;
  153.         res.emplace_back(-INF, 0);
  154.     }
  155.     while (j < rsz && rhs[j] == -INF) {
  156.         j++;
  157.         res.emplace_back(-INF, 0);
  158.     }
  159.     res.emplace_back(lhs[i] + rhs[j], i);
  160.     i++;
  161.     j++;
  162.  
  163.     int sum_sz = lsz + rsz;
  164.  
  165.     while (i + j < sum_sz) {
  166.         if (j == rsz) {
  167.             add_delta(res, lhs, i, i);
  168.             i++;
  169.         }
  170.         else if (i == lsz) {
  171.             add_delta(res, rhs, j, i - 1);
  172.             j++;
  173.         }
  174.         else if (lhs[i] - lhs[i - 1] > rhs[j] - rhs[j - 1]) {
  175.             add_delta(res, lhs, i, i);
  176.             i++;
  177.         }
  178.         else {
  179.             add_delta(res, rhs, j, i - 1);
  180.             j++;
  181.         }
  182.     }
  183.  
  184.     trim(res);
  185.     return res;
  186. }
  187.  
  188. const int MEM_DEPTH = 1;
  189.  
  190. universal_dp_state* merge_kids(universal_dp_state* lhs, universal_dp_state* rhs, int dep) {
  191.  
  192.     // used_lhs | used_rhs
  193.     // open_lhs ^ open_rhs
  194.     universal_dp_state* res = create_empty_ptr();
  195.     if (dep <= MEM_DEPTH) {
  196.         res->lhs = lhs;
  197.         res->rhs = rhs;
  198.     }
  199.  
  200.     for (int ml = 0; ml < 4; ml++) {
  201.         for (int mr = 0; mr < 4; mr++) {
  202.  
  203.             bool open_lhs = bit(ml, 1);
  204.             bool open_rhs = bit(mr, 1);
  205.  
  206.             int msk = get_mask(bit(ml, 0) || bit(mr, 0), open_lhs ^ open_rhs);
  207.             vector<ll>& cur_dp = res->dp[msk];
  208.             vector<ll>& cur_from = res->from[msk];
  209.  
  210.             vector<pair<ll, int>> candy = minkowski_sum(lhs->dp[ml], rhs->dp[mr]);
  211.             int sz = candy.size();
  212.  
  213.             if (cur_dp.size() < sz) {
  214.                 cur_dp.resize(sz, -INF);
  215.                 cur_from.resize(sz);
  216.             }
  217.  
  218.             int delta = (open_lhs && open_rhs);
  219.  
  220.             for (int i = 0; i + delta < sz; i++) {
  221.  
  222.                 ll val = candy[i + delta].first;
  223.                 int lhs_k = candy[i + delta].second;
  224.  
  225.                 if (val > cur_dp[i]) {
  226.                     cur_dp[i] = val;
  227.                     cur_from[i] = code(lhs_k, ml, i + delta - lhs_k, mr);
  228.                 }
  229.             }
  230.         }
  231.     }
  232.  
  233.     trim(res);
  234.     if (dep > MEM_DEPTH) {
  235.         delete(lhs);
  236.         delete(rhs);
  237.     }
  238.     return res;
  239. }
  240.  
  241. universal_dp_state* init_light_leaf(int v, int w) {
  242.     const auto& dp = dps[v]->dp;
  243.     universal_dp_state* ret = create_empty_ptr();
  244.  
  245.     // as is max(free, used) -> [0][0]
  246.     int mask_as_is = get_mask(false, false);
  247.     int sz0 = dp[0].size();
  248.     int sz1 = dp[1].size();
  249.     int sz = max(sz0, sz1);
  250.     ret->dp[mask_as_is].resize(sz, -INF);
  251.     ret->from[mask_as_is].resize(sz);
  252.     for (int i = 0; i < sz; i++) {
  253.         ll val0 = (i < sz0 ? dp[0][i] : -INF);
  254.         ll val1 = (i < sz1 ? dp[1][i] : -INF);
  255.         if (val1 > val0) {
  256.             ret->dp[mask_as_is][i] = val1;
  257.             ret->from[mask_as_is][i] = code(i, 1, 0, 0);
  258.         }
  259.         else {
  260.             ret->dp[mask_as_is][i] = val0;
  261.             ret->from[mask_as_is][i] = code(i, 0, 0, 0);
  262.         }
  263.     }
  264.  
  265.     // +1 edge free -> [1][1]
  266.     int mask_take_edge = get_mask(true, true);
  267.     sz = dp[0].size() + 1;
  268.     ret->dp[mask_take_edge].resize(sz, -INF);
  269.     ret->from[mask_take_edge].resize(sz);
  270.     for (int i = 1; i < sz; i++) {
  271.         ll val = dp[0][i - 1];
  272.         if (val != -INF) {
  273.             ret->dp[mask_take_edge][i] = val - w;
  274.             ret->from[mask_take_edge][i] = code(i - 1, 0, 0, 0);
  275.         }
  276.     }
  277.  
  278.     trim(ret);
  279.     return ret;
  280. }
  281.  
  282. universal_dp_state* dq_kids(int l, int r, const vector<pair<int, int>>& kids, int dep) {
  283.     if (l == r) {
  284.         // init state of one child
  285.         int v = kids[l].first;
  286.         int w = kids[l].second;
  287.         return init_light_leaf(v, w);
  288.     }
  289.  
  290.     int m = (l + r) >> 1;
  291.     universal_dp_state* left = dq_kids(l, m, kids, dep + 1);
  292.     universal_dp_state* right = dq_kids(m + 1, r, kids, dep + 1);
  293.     return merge_kids(left, right, dep);
  294. }
  295.  
  296. universal_dp_state* merge_heavy(universal_dp_state* lhs, universal_dp_state* rhs, bool single_lhs, bool single_rhs, int dep) {
  297.  
  298.     universal_dp_state* res = create_empty_ptr();
  299.     if (dep <= MEM_DEPTH) {
  300.         res->lhs = lhs;
  301.         res->rhs = rhs;
  302.     }
  303.  
  304.     for (int ml = 0; ml < 4; ml++) {
  305.         for (int mr = 0; mr < 4; mr++) {
  306.  
  307.             if (bit(ml, 1) && bit(mr, 0)) {
  308.                 continue;
  309.             }
  310.  
  311.             int msk = get_mask(bit(ml, 0), bit(mr, 1));
  312.             auto& cur_dp = res->dp[msk];
  313.             auto& cur_from = res->from[msk];
  314.  
  315.             vector<pair<ll, int>> candy = minkowski_sum(lhs->dp[ml], rhs->dp[mr]);
  316.             int sz = candy.size();
  317.  
  318.             if (cur_dp.size() < sz) {
  319.                 cur_dp.resize(sz, -INF);
  320.                 cur_from.resize(sz);
  321.             }
  322.  
  323.             for (int i = 0; i < sz; i++) {
  324.  
  325.                 ll val = candy[i].first;
  326.                 int lhs_k = candy[i].second;
  327.  
  328.                 if (val > cur_dp[i]) {
  329.                     cur_dp[i] = val;
  330.                     cur_from[i] = code(lhs_k, ml, i - lhs_k, mr);
  331.                 }
  332.             }
  333.         }
  334.     }
  335.     trim(res);
  336.     if (dep > MEM_DEPTH) {
  337.         if (!single_lhs) {
  338.             delete(lhs);
  339.         }
  340.         if (!single_rhs) {
  341.             delete(rhs);
  342.         }
  343.     }
  344.     return res;
  345. }
  346.  
  347. universal_dp_state* dq_heavy(int l, int r, const vector<int>& path, int dep) {
  348.     if (l == r) {
  349.         return dps[path[l]]->rhs; // heavy leaf
  350.     }
  351.  
  352.     int m = (l + r) >> 1;
  353.     universal_dp_state* left = dq_heavy(l, m, path, dep + 1);
  354.     universal_dp_state* right = dq_heavy(m + 1, r, path, dep + 1);
  355.     return merge_heavy(left, right, (l == m), (m + 1 == r), dep);
  356. }
  357.  
  358. void recalc_dp_heavy(int v) {
  359.     vector<int> heavy_path = { v };
  360.     int cur = v;
  361.     while (!g[cur].empty()) {
  362.         cur = g[cur][0].first;
  363.         heavy_path.push_back(cur);
  364.     }
  365.  
  366.     // fix dp[v] for v in path to be used in dq_heavy
  367.     for (int u : heavy_path) {
  368.         // сейчас там лежит [used][open]
  369.         // дополнительно можем взять ребро вниз в тяжелого сына и ребро вверх
  370.         // это влияет на used и open (последний должен стать 0). далее если used true, добавить вес вершины
  371.  
  372.         universal_dp_state* new_state = create_empty_ptr();
  373.  
  374.         for (int take_par = 0; take_par < 2; take_par++) {
  375.             if (u == 0 && take_par) {
  376.                 continue;
  377.             }
  378.             for (int take_heavy = 0; take_heavy < 2; take_heavy++) {
  379.                 if (g[u].empty() && take_heavy) {
  380.                     continue;
  381.                 }
  382.                 for (int state = 0; state < 4; state++) {
  383.  
  384.                     int used = bit(state, 0);
  385.                     int open = bit(state, 1);
  386.  
  387.                     ll delta_w = 0;
  388.                     int delta_cnt = 0;
  389.  
  390.                     if (take_par) {
  391.                         open ^= 1;
  392.                         delta_cnt++;
  393.                         delta_w -= par_w[u];
  394.                         used = 1;
  395.                     }
  396.                     if (take_heavy) {
  397.                         open ^= 1;
  398.                         delta_cnt++;
  399.                         delta_w -= g[u][0].second;
  400.                         used = 1;
  401.                     }
  402.                     if (open) {
  403.                         continue;
  404.                     }
  405.                     if (used) {
  406.                         delta_w += w[u];
  407.                     }
  408.  
  409.                     delta_cnt >>= 1;
  410.                     auto& old_dp = dps[u]->lhs->dp[state]; // light root
  411.                     int sz = old_dp.size() + delta_cnt;
  412.                     int new_msk = get_mask(take_par, take_heavy);
  413.                     auto& new_dp = new_state->dp[new_msk];
  414.                     auto& new_from = new_state->from[new_msk];
  415.  
  416.                     if (new_dp.size() < sz) {
  417.                         new_dp.resize(sz, -INF);
  418.                         new_from.resize(sz);
  419.                     }
  420.  
  421.                     for (int i = sz - 1; i >= delta_cnt; i--) {
  422.                         ll val = old_dp[i - delta_cnt];
  423.                         if (val != -INF) {
  424.                             val += delta_w;
  425.                             if (val > new_dp[i]) {
  426.                                 new_dp[i] = val;
  427.                                 new_from[i] = code(i - delta_cnt, state, take_par, take_heavy);
  428.                             }
  429.                         }
  430.                     }
  431.                 }
  432.             }
  433.         }
  434.  
  435.         trim(new_state);
  436.         dps[u]->rhs = new_state;
  437.     }
  438.  
  439.     universal_dp_state* res = dq_heavy(0, (int)heavy_path.size() - 1, heavy_path, 0);
  440.     heavy_root[v] = res;
  441.  
  442.     // fill dp[v]
  443.     for (int msk = 0; msk < 4; msk++) {
  444.  
  445.         const auto& old_dp = res->dp[msk];
  446.         int sz = old_dp.size();
  447.  
  448.         int new_msk = get_mask(bit(msk, 0), 0);
  449.         auto& cur_dp = dps[v]->dp[new_msk];
  450.         auto& cur_from = dps[v]->from[new_msk];
  451.  
  452.         if (cur_dp.size() < sz) {
  453.             cur_dp.resize(sz, -INF);
  454.             cur_from.resize(sz);
  455.         }
  456.  
  457.         for (int i = 0; i < sz; i++) {
  458.             if (old_dp[i] > cur_dp[i]) {
  459.                 cur_dp[i] = old_dp[i];
  460.                 cur_from[i] = code(i, msk, 0, 0);
  461.             }
  462.         }
  463.     }
  464.  
  465.     trim(dps[v]);
  466. }
  467.  
  468. void solve(int v) {
  469.  
  470.     dps[v] = create_empty_ptr();
  471.  
  472.     for (auto e : g[v]) {
  473.         solve(e.first);
  474.     }
  475.  
  476.     for (int i = 1; i < g[v].size(); i++) {
  477.         // запускает разделяйку на тяжёлых путях, начинающихся в детях; пересчитывает значение dp
  478.         recalc_dp_heavy(g[v][i].first);
  479.     }
  480.  
  481.     if (g[v].size() > 1) {
  482.         // мержит **уже правильно посчитанные** дпшки лёгких детей.
  483.         // В текущей вершине оставляет значения as is, исправить его должен будет пересчёт на соответствующем тяжёлом пути
  484.         dps[v]->lhs = dq_kids(1, (int)g[v].size() - 1, g[v], 0);
  485.     }
  486.     else {
  487.         dps[v]->lhs = create_empty_ptr();
  488.     }
  489. }
  490.  
  491. vector<int> ans[N];
  492.  
  493. void restore_heavy_root(int, int, int);
  494.  
  495. void dq_light_restore(int l, int r, const vector<pair<int, int>>& g, int k, int msk, universal_dp_state* node) {
  496.     if (l == r) {
  497.         int v = g[l].first;
  498.         int w = g[l].second;
  499.  
  500.         universal_dp_state* cur_node = (node != nullptr ? node : init_light_leaf(v, w));
  501.  
  502.         From from = decode(cur_node->from[msk][k]);
  503.  
  504.         delete(cur_node);
  505.  
  506.         if (from.lhs_k != k) {
  507.             ans[par_v[v]].push_back(v);
  508.         }
  509.         if (from.lhs_k > 0) {
  510.             restore_heavy_root(v, from.lhs_k, from.lhs_m);
  511.         }
  512.         return;
  513.     }
  514.  
  515.     universal_dp_state* cur_root = (node != nullptr ? node : dq_kids(l, r, g, 0));
  516.  
  517.     universal_dp_state* lhs = cur_root->lhs;
  518.     universal_dp_state* rhs = cur_root->rhs;
  519.  
  520.     From from = decode(cur_root->from[msk][k]);
  521.  
  522.     delete(cur_root);
  523.  
  524.     int m = (l + r) >> 1;
  525.  
  526.     if (from.lhs_k > 0) {
  527.         dq_light_restore(l, m, g, from.lhs_k, from.lhs_m, lhs);
  528.     }
  529.  
  530.     if (from.rhs_k > 0) {
  531.         dq_light_restore(m + 1, r, g, from.rhs_k, from.rhs_m, rhs);
  532.     }
  533. }
  534.  
  535. void dq_heavy_restore(int l, int r, const vector<int>& path, int k, int msk, universal_dp_state* node) {
  536.     if (l == r) {
  537.         int v = path[l];
  538.         universal_dp_state* src = dps[v]->rhs;
  539.         From from = decode(src->from[msk][k]);
  540.         if (from.rhs_k == 1) {
  541.             ans[v].push_back(par_v[v]);
  542.         }
  543.         if (from.rhs_m == 1) {
  544.             ans[v].push_back(g[v][0].first);
  545.         }
  546.         if (from.lhs_k > 0) {
  547.             dq_light_restore(1, (int)g[v].size() - 1, g[v], from.lhs_k, from.lhs_m, dps[v]->lhs);
  548.         }
  549.         return;
  550.     }
  551.  
  552.     universal_dp_state* cur_root = (node != nullptr ? node : dq_heavy(l, r, path, 0));
  553.  
  554.     universal_dp_state* lhs = cur_root->lhs;
  555.     universal_dp_state* rhs = cur_root->rhs;
  556.  
  557.     From from = decode(cur_root->from[msk][k]);
  558.  
  559.     delete(cur_root);
  560.  
  561.     int m = (l + r) >> 1;
  562.  
  563.     if (from.lhs_k > 0) {
  564.         dq_heavy_restore(l, m, path, from.lhs_k, from.lhs_m, lhs);
  565.     }
  566.  
  567.     if (from.rhs_k > 0) {
  568.         dq_heavy_restore(m + 1, r, path, from.rhs_k, from.rhs_m, rhs);
  569.     }
  570. }
  571.  
  572. void restore_heavy_root(int v, int k, int msk) {
  573.     From from = decode(dps[v]->from[msk][k]);
  574.  
  575.     vector<int> heavy_path = { v };
  576.     int cur = v;
  577.     while (!g[cur].empty()) {
  578.         cur = g[cur][0].first;
  579.         heavy_path.push_back(cur);
  580.     }
  581.  
  582.     dq_heavy_restore(0, (int)heavy_path.size() - 1, heavy_path, from.lhs_k, from.lhs_m, heavy_root[v]);
  583. }
  584.  
  585. void work_work(int n, int k, int t) {
  586.     for (int i = 0; i < n; i++) {
  587.         cin >> w[i];
  588.     }
  589.     for (int i = 0; i < n - 1; i++) {
  590.         int u, v, s;
  591.         cin >> u >> v >> s;
  592.         //u = i + 1, v = i + 2, s = 0;
  593.         --u; --v;
  594.         g[u].emplace_back(v, s);
  595.         g[v].emplace_back(u, s);
  596.     }
  597.  
  598.     prepare(0, -1);
  599.     solve(0);
  600.     recalc_dp_heavy(0);
  601.  
  602.     cout << dps[0]->dp[0][k] << '\n';
  603.     if (t == 1) {
  604.         restore_heavy_root(0, k, 0);
  605.         for (int u = 0; u < n; u++) {
  606.             for (int i = 0; i < ans[u].size(); i += 2) {
  607.                 cout << u + 1 << ' ' << ans[u][i] + 1 << ' ' << ans[u][i + 1] + 1 << '\n';
  608.             }
  609.         }
  610.     }
  611. }
  612.  
  613. int main() {
  614.     ios::sync_with_stdio(false);
  615.     cin.tie(nullptr);
  616.  
  617.     int n, k, t;
  618.     cin >> n >> k >> t;
  619.     work_work(n, k, t);
  620. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement