Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

using namespace std;
- int K, N, M, H;
- vector<vector<int>> trainSet;
- vector<int> classes;
- int treeSize = 0;
- struct tree {
- int num;
- bool isLeaf;
- int classNum;
- int feature;
- double value;
- struct tree *left;
- struct tree *right;
- };
- struct splitRes {
- int feature{};
- double value{};
- vector<vector<int>> groupsIndices;
- };
- vector<vector<int>> splitTrain(int feature, int value, const vector<int> &curIndices) {
- vector<vector<int>> resultIndices(2, vector<int>());
- for (int index : curIndices) {
- vector<int> row = trainSet[index];
- if (row[feature] < value) {
- resultIndices[0].push_back(index);
- } else {
- resultIndices[1].push_back(index);
- }
- }
- return resultIndices;
- }
- vector<int> getClassNums(vector<int> curIndices) {
- vector<int> classNums(curIndices.size());
- for (size_t i = 0; i < curIndices.size(); i++) {
- int index = curIndices[i];
- classNums[i] = classes[index];
- }
- return classNums;
- }
- double getGiniScore(vector<vector<int>> groupsIndices, const vector<int> &classNums) {
- double giniScore = 0.;
- unsigned long long rowCount = groupsIndices[0].size() + groupsIndices[1].size();
- for (auto &groupIndices : groupsIndices) {
- if (groupIndices.empty()) {
- continue;
- }
- double score = 0.;
- for (int classNum: classNums) {
- double aa = 0.;
- for (int groupIndex: groupIndices) {
- aa += classes[groupIndex] == classNum ? 1 : 0;
- }
- aa /= groupIndices.size();
- score += aa * aa;
- }
- giniScore += (1.0 - score) * ((double) groupIndices.size() / (double) rowCount);
- }
- return giniScore;
- }
- struct splitRes *getSplit(const vector<int> &curIndices) {
- vector<int> classNums = getClassNums(curIndices);
- int bestFeature = M + 1;
- double bestValue = 99999999999.;
- double bestGiniScore = 99999999999.;
- vector<vector<int>> bestGroupsIndices;
- for (int feature = 0; feature < M; feature++) {
- for (int index : curIndices) {
- vector<int> row = trainSet[index];
- vector<vector<int>> groupsIndices = splitTrain(feature, row[feature], curIndices);
- double giniScore = getGiniScore(groupsIndices, classNums);
- if (giniScore < bestGiniScore) {
- bestFeature = feature;
- bestValue = row[feature];
- bestGiniScore = giniScore;
- bestGroupsIndices = groupsIndices;
- }
- }
- }
- auto *res = new splitRes();
- res->feature = bestFeature;
- res->value = bestValue;
- res->groupsIndices = bestGroupsIndices;
- return res;
- }
- int getMajorClassNum(const vector<int> &indices) {
- vector<int> classCount(K, 0);
- for (int index: indices) {
- classCount[classes[index] - 1]++;
- }
- int ans = -1;
- int maxCount = -1;
- for (int i = 0; i < classCount.size(); i++) {
- if (classCount[i] > maxCount) {
- ans = i + 1;
- maxCount = classCount[i];
- }
- }
- return ans;
- }
- tree *getLeaf(int classNum) {
- auto *leaf = new tree();
- leaf->isLeaf = true;
- leaf->classNum = classNum;
- return leaf;
- }
- tree *getNode(int feature, double value, tree *left, tree *right) {
- auto *node = new tree();
- node->isLeaf = false;
- node->feature = feature;
- node->value = value;
- node->left = left;
- node->right = right;
- return node;
- }
- tree *split(splitRes *node, int depth) {
- vector<int> left = node->groupsIndices[0];
- vector<int> right = node->groupsIndices[1];
- if (left.empty() || right.empty()) {
- vector<int> nonEmptyGroup = left.empty() ? right : left;
- int majorClass = getMajorClassNum(nonEmptyGroup);
- tree *treeLeft = getLeaf(majorClass);
- tree *treeRight = getLeaf(majorClass);
- return getNode(node->feature, node->value, treeLeft, treeRight);
- }
- if (depth >= H) {
- tree *treeLeft = getLeaf(getMajorClassNum(left));
- tree *treeRight = getLeaf(getMajorClassNum(right));
- return getNode(node->feature, node->value, treeLeft, treeRight);
- }
- tree *treeLeft = left.size() <= 1 ? getLeaf(getMajorClassNum(left)) : split(getSplit(left), depth + 1);
- tree *treeRight = right.size() <= 1 ? getLeaf(getMajorClassNum(right)) : split(getSplit(right), depth + 1);
- return getNode(node->feature, node->value, treeLeft, treeRight);
- }
- struct tree *buildTree() {
- vector<int> indices(N);
- for (size_t i = 0; i < N; i++) {
- indices[i] = i;
- }
- splitRes *res = getSplit(indices);
- tree *root = split(res, 1);
- return root;
- }
- void numerateTree(tree* node) {
- node->num = ++treeSize;
- if (!node->isLeaf) {
- numerateTree(node->left);
- numerateTree(node->right);
- }
- }
- void print(tree* node) {
- if (node->isLeaf) {
- printf("C %d\n", node->classNum);
- } else {
- printf("Q %d %f %d %d\n", node->feature + 1, node->value, node->left->num, node->right->num);
- print(node->left);
- print(node->right);
- }
- }
- int main() {
- //ios_base::sync_with_stdio(false);
- cin >> M >> K >> H >> N;
- trainSet.resize(N);
- classes.resize(N);
- for (size_t i = 0; i < N; i++) {
- trainSet[i].resize(M);
- for (size_t j = 0; j < M; j++) {
- cin >> trainSet[i][j];
- }
- cin >> classes[i];
- }
- tree *root = buildTree();
- numerateTree(root);
- printf("%d\n", treeSize);
- print(root);
- return 0;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement