Advertisement
VladNitu

Vlad MurTree impl ACC

Mar 1st, 2023
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 5.02 KB | None | 0 0
  1. #define MAX_DEPTH 5
  2.  
  3. //#include "Files.h"
  4.  
  5. #include <climits>
  6. #include <iostream>
  7. #include <vector>
  8. #include <bitset>
  9. #include <unordered_map>
  10. #include <chrono>
  11. #include <filesystem>
  12.  
  13. int D, label, N = 0, dim = 0, feature;
  14. std::vector<int> labels, ftrs;
  15. std::vector<std::vector<int>> features;
  16. std::string line;
  17.  
  18. inline size_t hash_vector(const std::vector<int> &vec) {
  19.     std::size_t seed = vec.size();
  20.     for (auto &i: vec) {
  21.         seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  22.     }
  23.     return seed;
  24. }
  25.  
  26. std::unordered_map<size_t, int> dp[MAX_DEPTH];
  27.  
  28. std::bitset<500> featuresAvailable;
  29.  
  30.  
  31. inline int makeSplit(int depth, const std::vector<int> &indices, const std::vector<int> &parent_indices) {
  32.  
  33.     int ans = INT_MAX;
  34.  
  35.     size_t hash_value = hash_vector(indices);
  36.  
  37.     if (dp[depth].find(hash_value) != dp[depth].end())
  38.         return dp[depth][hash_value];
  39.  
  40.     if (depth == D - 1) { // D = 2 speed-up case
  41.  
  42.         for (int splitFeature = 0; splitFeature < dim; ++splitFeature)
  43.             if (featuresAvailable[splitFeature]) {
  44.  
  45.                 int c0_0 = 0, c0_1 = 0, c1_0 = 0, c1_1 = 0; // c1_0 -> negative labels with feature == 1
  46.  
  47.                 for (const int &idx: indices) {
  48.                     const std::vector<int> &v = features[idx];
  49.                     if (v[splitFeature] == 0) {
  50.                         if (labels[idx] == 0)
  51.                             c0_0++;
  52.                         else
  53.                             c0_1++;
  54.                     } else {
  55.                         if (labels[idx] == 0)
  56.                             c1_0++;
  57.                         else
  58.                             c1_1++;
  59.                     }
  60.                 }
  61.  
  62.                 int leftMissmatch = std::min(c0_0, c0_1);
  63.                 int rightMissmatch = std::min(c1_0, c1_1);
  64.  
  65.                 ans = std::min(ans, leftMissmatch + rightMissmatch);
  66.             }
  67.  
  68.         dp[depth][hash_value] = ans;
  69.         return ans;
  70.     } else {
  71.         for (int splitFeature = 0; splitFeature < dim; ++splitFeature)
  72.             if (featuresAvailable[splitFeature]) {
  73.  
  74.                 featuresAvailable[splitFeature] = false;
  75.  
  76.                 std::vector<int> negativeIndices{};
  77.                 std::vector<int> positiveIndices{};
  78.  
  79.                 for (const int &idx: indices) {
  80.                     const std::vector<int> &v = features[idx];
  81.                     if (v[splitFeature] == 0)
  82.                         negativeIndices.emplace_back(idx);
  83.                     else
  84.                         positiveIndices.emplace_back(idx);
  85.                 }
  86.                 int leftMissmatch = 0, rightMissmatch = 0;
  87.                 leftMissmatch = makeSplit(depth + 1, negativeIndices, indices);
  88.  
  89.  
  90.                 if (depth == 0)
  91.                     rightMissmatch = makeSplit(depth + 1, positiveIndices, indices);
  92.                 else {
  93.                     size_t parent_hash_value = hash_vector(parent_indices);
  94.                     if (dp[depth - 1].find(parent_hash_value) == dp[depth - 1].end() ||
  95.                         leftMissmatch < dp[depth - 1][parent_hash_value])
  96.                         rightMissmatch = makeSplit(depth + 1, positiveIndices, indices);
  97.                 }
  98.  
  99.                 ans = std::min(ans, leftMissmatch + rightMissmatch);
  100.  
  101.                 featuresAvailable[splitFeature] = true;
  102.  
  103.                 if (ans == 0)
  104.                     break;
  105.             }
  106.  
  107.     }
  108.  
  109.     dp[depth][hash_value] = ans;
  110.     return ans;
  111. }
  112.  
  113.  
  114. inline std::vector<int> parseInputSlowStringStream() {
  115.     std::stringstream ss;
  116.  
  117.     std::getline(std::cin, line);
  118.     ss << line;
  119.     ss >> D;
  120.     ss.clear();
  121.  
  122.     while (std::getline(std::cin, line)) {
  123.         ss << line;
  124.         ss >> label;
  125.         labels.emplace_back(label);
  126.  
  127.         ftrs = std::vector<int>{};
  128.         while (ss >> feature)
  129.             ftrs.emplace_back(feature);
  130.  
  131.         features.emplace_back(ftrs);
  132.  
  133.         ss.clear();
  134.     }
  135.  
  136.     N = features.size(); // # of datapoints
  137.     dim = features[0].size();
  138.  
  139.     std::vector<int> indices(N);
  140.     for (int i = 0; i < N; ++i)
  141.         indices[i] = i;
  142.  
  143.     for (int i = 0; i < dim; ++i)
  144.         featuresAvailable[i] = true;
  145.  
  146.     return indices;
  147. }
  148.  
  149. inline void rwFrom(const char *read_filename, const char *write_filename) {
  150.     // Speed-up reading
  151.     std::ios::sync_with_stdio(false);
  152.     std::cin.tie(NULL);
  153.  
  154.     freopen(read_filename, "r", stdin);
  155.     freopen(write_filename, "w", stdout);
  156. }
  157.  
  158. int main() {
  159.  
  160. //    std::cout << "we are currently in: " << std::filesystem::current_path() << '\n';
  161. //
  162. //    auto start = std::chrono::steady_clock::now();
  163. //
  164. //    rwFrom(Files::mushroom_in, Files::mushroom_out);
  165.     std::vector<int> indices = parseInputSlowStringStream();
  166.  
  167.     std::cout << makeSplit(0, indices, indices) << '\n'; // Solve
  168.  
  169. //    auto end = std::chrono::steady_clock::now();
  170.  
  171. //    std::cout << "Elapsed time: "
  172. //              << std::chrono::duration_cast<std::chrono::seconds>(end - start).count()
  173. //              << " sec";
  174.  
  175.     return 0;
  176. }
  177.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement