Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include<bits/stdc++.h>
// NOTE(review): this macro textually replaces every later `int` with
// `long double`, so every value spelled "int" below (features, labels,
// thresholds, error counts) is floating point; genuine integers in this file
// are spelled `signed`. It also means a real `int` cannot be written after
// this line (hence `signed main`).
#define int long double
using namespace std;
// D: number of feature columns per row; the label is stored at index D.
// N: number of rows read for training, and again for testing, in main.
const signed D=10,N=1000;
// Global training set filled by main: N rows of D features followed by a label.
vector<vector<int>> x;
// Binary decision-tree node.
// Internal node: l and r are the children; the row is routed to l when its
// feature k is strictly greater than theta, otherwise to r.
// Leaf: l == r == nullptr, k == -1, theta == 0, and val is the predicted label.
// Every member has a default so no constructor leaves a field indeterminate.
struct node{
    node *l=nullptr,*r=nullptr;
    signed k=-1;   // feature index tested here (-1 for leaves)
    int val=0;     // predicted label; meaningful only for leaves
    int theta=0;   // split threshold; 0 for leaves
    // Internal node splitting on feature k at threshold theta.
    node(node *l,node *r,signed k,int theta):l(l),r(r),k(k),theta(theta){}
    // Leaf predicting val; explicit so a bare number never converts to a node.
    explicit node(int val):val(val){}
};
// Column index that cmp() orders rows by; build() sets it before each sort.
signed cp;
// Strict-weak-ordering comparator for std::sort: orders rows by column cp.
// Rows are taken by const reference so sorting does not copy two whole rows
// on every comparison.
bool cmp(const vector<int> &a,const vector<int> &b){
    return a[cp]<b[cp];
}
- node *build(vector<vector<int>> x){
- //cout<<x.size()<<'\n';
- bool flag=true;
- for(signed i=1;i<x.size();i++){
- if(x[i][D]!=x[0][D]){
- flag=false;
- break;
- }
- }
- if(flag){
- return new node(x[0][D]);
- }
- signed k=0;
- int theta=0,err=1e9;
- for(signed j=0;j<D;j++){
- cp=j;
- sort(x.begin(),x.end(),cmp);
- signed neg_cnt=0,cnt=0;
- for(signed i=0;i<x.size();i++){
- if(x[i][D]<0){
- neg_cnt++;
- }
- }
- for(signed i=0;i<(signed)(x.size())-1;i++){
- if(x[i][D]<0){
- cnt++;
- neg_cnt--;
- assert(neg_cnt>=0);
- }
- if(x[i][j]==x[i+1][j]){
- continue;
- }
- assert(x[i][j]<x[i+1][j]);
- int e2=min(cnt+(signed)(x.size())-i-neg_cnt,i-cnt+neg_cnt);
- //cout<<e2<<' '<<err<<'\n';
- if(e2<err){
- err=e2;
- k=j;
- theta=(x[i][j]+x[i+1][j])/2;
- //assert(x[i][j]!=x[i+1][j]);
- }
- }
- }
- cp=k;
- sort(x.begin(),x.end(),cmp);
- vector<vector<int>> a,b;
- for(auto &i:x){
- if(i[k]>theta){
- a.push_back(i);
- }
- else{
- b.push_back(i);
- }
- }
- cout<<a.size()<<' '<<b.size()<<'\n';
- return new node(build(a),build(b),k,theta);
- }
- int predict(node *a,vector<int> x){
- if(a->l){
- return (x[a->k]>(a->theta))?(predict(a->l,x)):(predict(a->r,x));
- }
- else{
- return a->val;
- }
- }
- signed main(){
- x.resize(N);
- for(signed i=0;i<N;i++){
- cout<<i<<endl;
- x[i].resize(D+1);
- for(signed j=0;j<D+1;j++){
- cin>>x[i][j];
- }
- }
- cout<<"read train done\n";
- node *root=build(x);
- cout<<"build done\n";
- int cnt=0;
- for(signed i=0;i<N;i++){
- vector<int> tmp;
- tmp.resize(D);
- for(signed j=0;j<D;j++){
- cin>>tmp[j];
- }
- int t;
- cin>>t;
- if(t!=predict(root,tmp)){
- cnt++;
- }
- }
- cout<<cnt/N<<'\n';
- }
RAW Paste Data