shopnobaj

ML_offline_decision_tree_learning

Mar 26th, 2015
226
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 6.79 KB | None | 0 0
  1. #include<iostream>
  2. #include<fstream>
  3. #include<vector>
  4. #include<cmath>
  5. #include<sstream>
  6. #include<ctime>
  7. #include<cstdlib>
  8. #define pb push_back
  9. using namespace std;
  10.  
  11. struct example
  12. {
  13.     int vals[10];
  14.     example(int arr[])
  15.     {
  16.         for(int i=0;i<10;i++) vals[i]=arr[i];
  17.     }
  18. };
  19.  
  20. struct node
  21. {
  22.     int type,attr,val;
  23.     int child[11];
  24. };
  25.  
  26. class decision_tree
  27. {
  28.     int node_cnt;
  29.     vector<node> tree;
  30. public:
  31.     decision_tree()
  32.     {
  33.         node_cnt=0;
  34.     }
  35.     void make_tree(vector<example> &examples)
  36.     {
  37.         ID3(examples,0);
  38.     }
  39.     int get_result(example ex)
  40.     {
  41.         return match(ex,0);
  42.     }
  43. private:
  44.     int match(example ex,int node)
  45.     {
  46.         if(tree[node].type==0) return tree[node].val;
  47.         return match(ex,tree[node].child[ex.vals[tree[node].attr]]);
  48.     }
  49.     int ID3(vector<example> &examples,int attr_mask)
  50.     {
  51.         //cout<<examples.size()<<" ";
  52.         int pos=0,neg=0;
  53.         for(int i=0;i<examples.size();i++)
  54.         {
  55.             if(examples[i].vals[9]==1) pos++;
  56.             else neg++;
  57.         }
  58.         //cout<<pos<<" "<<neg<<endl;
  59.         int node_no=node_cnt;
  60.         node_cnt++;
  61.         node nw_node;
  62.         if(pos==examples.size())
  63.         {
  64.             nw_node.type=0;
  65.             nw_node.val=1;
  66.             tree.push_back(nw_node);
  67.             return node_no;
  68.         }
  69.         else if(neg==examples.size())
  70.         {
  71.             nw_node.type=0;
  72.             nw_node.val=0;
  73.             tree.push_back(nw_node);
  74.             return node_no;
  75.         }
  76.         if(attr_mask==(1<<9)-1)
  77.         {
  78.             nw_node.type=0;
  79.             nw_node.val=pos>neg?1:0;
  80.             tree.push_back(nw_node);
  81.             return node_no;
  82.         }
  83.         //cout<<"ok\n";
  84.         int attr=get_best_attr(examples,attr_mask);
  85.         //cout<<examples.size()<<" "<<attr<<endl;
  86.         nw_node.type=1;
  87.         nw_node.attr=attr;
  88.         tree.push_back(nw_node);
  89.         vector<example> tmp[11];
  90.         for(int i=0;i<examples.size();i++)
  91.         {
  92.             tmp[examples[i].vals[attr]].push_back(examples[i]);
  93.         }
  94.         for(int i=1;i<=10;i++)
  95.         {
  96.             if(tmp[i].empty())
  97.             {
  98.                 node nd;
  99.                 nd.type=0;
  100.                 nd.val=pos>neg?1:0;
  101.                 tree.push_back(nd);
  102.                 tree[node_no].child[i]=node_cnt;
  103.                 node_cnt++;
  104.             }
  105.             else
  106.             {
  107.                 int n=ID3(tmp[i],attr_mask | (1<<attr));
  108.                 tree[node_no].child[i]=n;
  109.             }
  110.         }
  111.         return node_no;
  112.     }
  113.     int get_best_attr(vector<example> &examples,int attr_mask)
  114.     {
  115.         //cout<<examples.size()<<" "<<attr_mask<<" ";
  116.         double mn=5.00; //INF
  117.         int attr;
  118.         for(int i=0;i<9;i++)
  119.         {
  120.             if((attr_mask & (1<<i))==0)
  121.             {
  122.                 //cout<<"ok ";
  123.                 double tmp=expected_entropy(examples,i);
  124.                 //cout<<tmp<<" "<<endl;
  125.                 if(tmp<mn) mn=tmp,attr=i;
  126.             }
  127.         }
  128.         return attr;
  129.     }
  130.     double expected_entropy(vector<example> &examples,int attr)
  131.     {
  132.         //cout<<"inside exp_entro "<<attr<<" ";
  133.         int cnt[2][11]={0};
  134.         for(int i=0;i<examples.size();i++)
  135.         {
  136.             cnt[examples[i].vals[9]][examples[i].vals[attr]]++;
  137.         }
  138.         int total=examples.size();
  139.         //cout<<"total "<<total<<"\n";
  140.         double ret=0;
  141.         for(int i=1;i<=10;i++)
  142.         {
  143.             //cout<<cnt[0][i]<<" "<<cnt[1][i]<<" ";
  144.             ret+=double(cnt[0][i]+cnt[1][i])/total;ret*=entropy(cnt[0][i],cnt[1][i]);
  145.             //cout<<ret<<endl;
  146.         }
  147.         return ret;
  148.     }
  149.     double entropy(int a,int b)
  150.     {
  151.         //cout<<"inside entro "<<a<<" "<<b<<" ";
  152.         if(a*b==0) return 0;
  153.         double x=(double) a/(a+b);
  154.         //cout<<x<<" ";
  155.         //if(x==1)
  156.         return -x*log(x)-(1-x)*log(1-x);
  157.     }
  158. };
  159.  
  160. int main()
  161. {
  162.     vector<example> examples;
  163.     ifstream fin("data.csv");
  164.     ofstream fout("results.txt");
  165.     string str;
  166.  
  167.     /** Get examples from file **/
  168.     while(fin>>str)
  169.     {
  170.         for(int i=0;i<str.size();i++) if(str[i]==',') str[i]=' ';
  171.         stringstream ss;
  172.         ss<<str;
  173.         int arr[10];
  174.         for(int j=0;j<10;j++) ss>>arr[j];
  175.         example ex(arr);
  176.         examples.push_back(ex);
  177.     }
  178.  
  179.     /**Iterate 100 times **/
  180.     int no_itr=10,tst=1;
  181.     double av_ac,av_pr,av_rc,av_fm,av_gm;
  182.     av_ac=av_pr=av_rc=av_fm=av_gm=0;
  183.     while(tst++<=no_itr)
  184.     {
  185.         vector<example> train,test;
  186.         srand(time(NULL)+rand()%187);
  187.         for(int i=0;i<examples.size();i++)
  188.         {
  189.             if(rand()%10<8) train.pb(examples[i]);
  190.             else test.pb(examples[i]);
  191.         }
  192.         //cout<<train.size()<<" "<<test.size()<<endl;
  193.         vector<example> tmp;
  194.         for(int i=0;i<6;i++) tmp.pb(examples[i]);
  195.         decision_tree tree;
  196.         tree.make_tree(train);
  197.         //cout<<tree.get_result(test[0])<<endl;
  198.         int pos,neg,tp,tn,fp,fn;
  199.         pos=neg=tp=tn=fp=fn=0;
  200.         for(int i=0;i<test.size();i++)
  201.         {
  202.             int tv=test[i].vals[9];
  203.             if(tv) pos++;
  204.             else neg++;
  205.             int mv=tree.get_result(test[i]);
  206.             if(mv==tv)
  207.             {
  208.                 if(mv) tp++;
  209.                 else tn++;
  210.             }
  211.             else
  212.             {
  213.                 if(mv) fn++;
  214.                 else fp++;
  215.             }
  216.         }
  217.         //cout<<pos<<" "<<neg<<" "<<tp<<" "<<tn<<" "<<fp<<" "<<fn<<endl;
  218.         fout<<"Test no: "<<tst-1<<"\n------------------------\n";
  219.         fout<<"positive examples: "<<pos<<endl;
  220.         fout<<"negative examples: "<<neg<<endl;
  221.         fout<<"positive detected correctly: "<<tp<<endl;
  222.         fout<<"negative detected correctly: "<<tn<<endl;
  223.         fout<<"false positive: "<<fp<<endl;
  224.         fout<<"false negative: "<<fn<<endl;
  225.         fout<<"accuracy: ";
  226.         double ac=(double)(tp+tn)/(pos+neg);av_ac+=ac;
  227.         fout<<ac<<endl;
  228.         fout<<"precision: ";
  229.         double pr=(double)tp/(tp+fp);av_pr+=pr;
  230.         fout<<pr<<endl;
  231.         fout<<"recall: ";
  232.         double rc=(double)tp/pos;av_rc+=rc;
  233.         fout<<rc<<endl;
  234.         fout<<"f-measure: ";
  235.         double fm=(double)2*pr*rc/(pr+rc);av_fm+=fm;
  236.         fout<<fm<<endl;
  237.         fout<<"g-mean: ";
  238.         double gm=sqrt(pr*rc);av_gm+=gm;
  239.         fout<<gm<<endl;
  240.         fout<<endl;
  241.     }
  242.     fout<<"average performance\n-------------------\n";
  243.     fout<<"average accuracy: "<<av_ac/no_itr<<endl;
  244.     fout<<"average precision: "<<av_pr/no_itr<<endl;
  245.     fout<<"average recall: "<<av_rc/no_itr<<endl;
  246.     fout<<"average f-measure: "<<av_fm/no_itr<<endl;
  247.     fout<<"average g-mean: "<<av_gm/no_itr<<endl;
  248.     return 0;
  249. }
Add Comment
Please, Sign In to add comment