double112233

HTML HW6 p14

Jan 15th, 2021
650
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. #include<bits/stdc++.h>
  2. #define int long double
  3. using namespace std;
  4. const signed D=10,N=1000;
  5. vector<vector<int>> x;
  6. struct node{
  7.     node *l,*r;
  8.     signed k;
  9.     int val,theta;
  10.     node(node *l,node *r,signed k,int theta):l(l),r(r),k(k),theta(theta){}
  11.     node(int val):l(nullptr),r(nullptr),k(-1.0),theta(0.0),val(val){}
  12. };
  13. signed cp;
  14. bool cmp(vector<int> a,vector<int> b){
  15.     return a[cp]<b[cp];
  16. }
  17. node *build(vector<vector<int>> x){
  18.     //cout<<x.size()<<'\n';
  19.     bool flag=true;
  20.     for(signed i=1;i<x.size();i++){
  21.         if(x[i][D]!=x[0][D]){
  22.             flag=false;
  23.             break;
  24.         }
  25.     }
  26.     if(flag){
  27.         return new node(x[0][D]);
  28.     }
  29.     signed k=0;
  30.     int theta=0,err=1e9;
  31.     for(signed j=0;j<D;j++){
  32.         cp=j;
  33.         sort(x.begin(),x.end(),cmp);
  34.         signed neg_cnt=0,cnt=0;
  35.         for(signed i=0;i<x.size();i++){
  36.             if(x[i][D]<0){
  37.                 neg_cnt++;
  38.             }
  39.         }
  40.         for(signed i=0;i<(signed)(x.size())-1;i++){
  41.             if(x[i][D]<0){
  42.                 cnt++;
  43.                 neg_cnt--;
  44.                 assert(neg_cnt>=0);
  45.             }
  46.             if(x[i][j]==x[i+1][j]){
  47.                 continue;
  48.             }
  49.             assert(x[i][j]<x[i+1][j]);
  50.             int e2=min(cnt+(signed)(x.size())-i-neg_cnt,i-cnt+neg_cnt);
  51.             //cout<<e2<<' '<<err<<'\n';
  52.             if(e2<err){
  53.                 err=e2;
  54.                 k=j;
  55.                 theta=(x[i][j]+x[i+1][j])/2;
  56.                 //assert(x[i][j]!=x[i+1][j]);
  57.             }
  58.         }
  59.     }
  60.     cp=k;
  61.     sort(x.begin(),x.end(),cmp);
  62.     vector<vector<int>> a,b;
  63.     for(auto &i:x){
  64.         if(i[k]>theta){
  65.             a.push_back(i);
  66.         }
  67.         else{
  68.             b.push_back(i);
  69.         }
  70.     }
  71.     cout<<a.size()<<' '<<b.size()<<'\n';
  72.     return new node(build(a),build(b),k,theta);
  73. }
  74. int predict(node *a,vector<int> x){
  75.     if(a->l){
  76.         return (x[a->k]>(a->theta))?(predict(a->l,x)):(predict(a->r,x));
  77.     }
  78.     else{
  79.         return a->val;
  80.     }
  81. }
  82. signed main(){
  83.     x.resize(N);
  84.     for(signed i=0;i<N;i++){
  85.         cout<<i<<endl;
  86.         x[i].resize(D+1);
  87.         for(signed j=0;j<D+1;j++){
  88.             cin>>x[i][j];
  89.         }
  90.     }
  91.     cout<<"read train done\n";
  92.     node *root=build(x);
  93.     cout<<"build done\n";
  94.     int cnt=0;
  95.     for(signed i=0;i<N;i++){
  96.         vector<int> tmp;
  97.         tmp.resize(D);
  98.         for(signed j=0;j<D;j++){
  99.             cin>>tmp[j];
  100.         }
  101.         int t;
  102.         cin>>t;
  103.         if(t!=predict(root,tmp)){
  104.             cnt++;
  105.         }
  106.     }
  107.     cout<<cnt/N<<'\n';
  108. }
RAW Paste Data