Advertisement
Guest User

DT

a guest
Nov 17th, 2019
106
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 5.71 KB | None | 0 0
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <memory>
#include <vector>
  4.  
  5. using namespace std;
  6.  
  7. int K, N, M, H;
  8. vector<vector<int>> trainSet;
  9. vector<int> classes;
  10. int treeSize = 0;
  11.  
  12. struct tree {
  13.     int num;
  14.     bool isLeaf;
  15.     int classNum;
  16.     int feature;
  17.     double value;
  18.     struct tree *left;
  19.     struct tree *right;
  20. };
  21.  
  22. struct splitRes {
  23.     int feature{};
  24.     double value{};
  25.     vector<vector<int>> groupsIndices;
  26. };
  27.  
  28. vector<vector<int>> splitTrain(int feature, int value, const vector<int> &curIndices) {
  29.     vector<vector<int>> resultIndices(2, vector<int>());
  30.     for (int index : curIndices) {
  31.         vector<int> row = trainSet[index];
  32.         if (row[feature] < value) {
  33.             resultIndices[0].push_back(index);
  34.         } else {
  35.             resultIndices[1].push_back(index);
  36.         }
  37.     }
  38.     return resultIndices;
  39. }
  40.  
  41. vector<int> getClassNums(vector<int> curIndices) {
  42.     vector<int> classNums(curIndices.size());
  43.     for (size_t i = 0; i < curIndices.size(); i++) {
  44.         int index = curIndices[i];
  45.         classNums[i] = classes[index];
  46.     }
  47.     return classNums;
  48. }
  49.  
  50. double getGiniScore(vector<vector<int>> groupsIndices, const vector<int> &classNums) {
  51.     double giniScore = 0.;
  52.     unsigned long long rowCount = groupsIndices[0].size() + groupsIndices[1].size();
  53.  
  54.     for (auto &groupIndices : groupsIndices) {
  55.         if (groupIndices.empty()) {
  56.             continue;
  57.         }
  58.         double score = 0.;
  59.  
  60.         for (int classNum: classNums) {
  61.             double aa = 0.;
  62.             for (int groupIndex: groupIndices) {
  63.                 aa += classes[groupIndex] == classNum ? 1 : 0;
  64.             }
  65.             aa /= groupIndices.size();
  66.             score += aa * aa;
  67.         }
  68.         giniScore += (1.0 - score) * ((double) groupIndices.size() / (double) rowCount);
  69.     }
  70.  
  71.     return giniScore;
  72. }
  73.  
  74. struct splitRes *getSplit(const vector<int> &curIndices) {
  75.     vector<int> classNums = getClassNums(curIndices);
  76.     int bestFeature = M + 1;
  77.     double bestValue = 99999999999.;
  78.     double bestGiniScore = 99999999999.;
  79.     vector<vector<int>> bestGroupsIndices;
  80.  
  81.     for (int feature = 0; feature < M; feature++) {
  82.         for (int index : curIndices) {
  83.             vector<int> row = trainSet[index];
  84.             vector<vector<int>> groupsIndices = splitTrain(feature, row[feature], curIndices);
  85.             double giniScore = getGiniScore(groupsIndices, classNums);
  86.  
  87.             if (giniScore < bestGiniScore) {
  88.                 bestFeature = feature;
  89.                 bestValue = row[feature];
  90.                 bestGiniScore = giniScore;
  91.                 bestGroupsIndices = groupsIndices;
  92.             }
  93.         }
  94.     }
  95.  
  96.     auto *res = new splitRes();
  97.     res->feature = bestFeature;
  98.     res->value = bestValue;
  99.     res->groupsIndices = bestGroupsIndices;
  100.     return res;
  101. }
  102.  
  103. int getMajorClassNum(const vector<int> &indices) {
  104.     vector<int> classCount(K, 0);
  105.     for (int index: indices) {
  106.         classCount[classes[index] - 1]++;
  107.     }
  108.     int ans = -1;
  109.     int maxCount = -1;
  110.     for (int i = 0; i < classCount.size(); i++) {
  111.         if (classCount[i] > maxCount) {
  112.             ans = i + 1;
  113.             maxCount = classCount[i];
  114.         }
  115.     }
  116.     return ans;
  117. }
  118.  
  119. tree *getLeaf(int classNum) {
  120.     auto *leaf = new tree();
  121.     leaf->isLeaf = true;
  122.     leaf->classNum = classNum;
  123.     return leaf;
  124. }
  125.  
  126. tree *getNode(int feature, double value, tree *left, tree *right) {
  127.     auto *node = new tree();
  128.     node->isLeaf = false;
  129.     node->feature = feature;
  130.     node->value = value;
  131.     node->left = left;
  132.     node->right = right;
  133.     return node;
  134. }
  135.  
  136. tree *split(splitRes *node, int depth) {
  137.     vector<int> left = node->groupsIndices[0];
  138.     vector<int> right = node->groupsIndices[1];
  139.  
  140.     if (left.empty() || right.empty()) {
  141.         vector<int> nonEmptyGroup = left.empty() ? right : left;
  142.         int majorClass = getMajorClassNum(nonEmptyGroup);
  143.         tree *treeLeft = getLeaf(majorClass);
  144.         tree *treeRight = getLeaf(majorClass);
  145.         return getNode(node->feature, node->value, treeLeft, treeRight);
  146.     }
  147.  
  148.     if (depth >= H) {
  149.         tree *treeLeft = getLeaf(getMajorClassNum(left));
  150.         tree *treeRight = getLeaf(getMajorClassNum(right));
  151.         return getNode(node->feature, node->value, treeLeft, treeRight);
  152.     }
  153.  
  154.     tree *treeLeft = left.size() <= 1 ? getLeaf(getMajorClassNum(left)) : split(getSplit(left), depth + 1);
  155.     tree *treeRight = right.size() <= 1 ? getLeaf(getMajorClassNum(right)) : split(getSplit(right), depth + 1);
  156.     return getNode(node->feature, node->value, treeLeft, treeRight);
  157. }
  158.  
  159. struct tree *buildTree() {
  160.     vector<int> indices(N);
  161.     for (size_t i = 0; i < N; i++) {
  162.         indices[i] = i;
  163.     }
  164.     splitRes *res = getSplit(indices);
  165.     tree *root = split(res, 1);
  166.     return root;
  167. }
  168.  
  169. void numerateTree(tree* node) {
  170.     node->num = ++treeSize;
  171.     if (!node->isLeaf) {
  172.         numerateTree(node->left);
  173.         numerateTree(node->right);
  174.     }
  175. }
  176.  
  177. void print(tree* node) {
  178.     if (node->isLeaf) {
  179.         printf("C %d\n", node->classNum);
  180.     } else {
  181.         printf("Q %d %f %d %d\n", node->feature + 1, node->value, node->left->num, node->right->num);
  182.         print(node->left);
  183.         print(node->right);
  184.     }
  185. }
  186.  
  187. int main() {
  188.     //ios_base::sync_with_stdio(false);
  189.  
  190.     cin >> M >> K >> H >> N;
  191.     trainSet.resize(N);
  192.     classes.resize(N);
  193.  
  194.     for (size_t i = 0; i < N; i++) {
  195.         trainSet[i].resize(M);
  196.         for (size_t j = 0; j < M; j++) {
  197.             cin >> trainSet[i][j];
  198.         }
  199.         cin >> classes[i];
  200.     }
  201.  
  202.     tree *root = buildTree();
  203.     numerateTree(root);
  204.     printf("%d\n", treeSize);
  205.     print(root);
  206.  
  207.     return 0;
  208. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement