Advertisement
LinKin

ML Decision tree

Feb 8th, 2014
175
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 5.75 KB | None | 0 0
  1. #pragma comment(linker, "/STACK:16777216")
  2.  
  3. #include<stdio.h>
  4. #include<string.h>
  5. #include<math.h>
  6. #include<stdlib.h>
  7. #include<ctype.h>
  8. #include<assert.h>
  9. #include<iostream>
  10. #include<vector>
  11. #include<stack>
  12. #include<queue>
  13. #include<set>
  14. #include<map>
  15. #include<string>
  16. #include<utility>
  17. #include<algorithm>
  18. #include<list>
  19. using namespace std;
  20.  
  // ---- Contest-style helper macros ------------------------------------------
  // Every macro parameter is wrapped in parentheses so that expression
  // arguments (e.g. SQR(a+b), chk(x|y,k)) expand with the intended precedence.

  // Fill an array with 0 / -1 bytes. `a` must be a real array (sizeof is applied
  // to it), not a pointer.
  #define CLR(a) memset((a),0,sizeof(a))
  #define SET(a) memset((a),-1,sizeof(a))
  // Token shorthands for common member names.
  #define pb push_back
  // Container size as a signed long.
  #define SZ(a) ((long)(a).size())
  // Expands to the begin/end iterator pair of a container.
  #define ALL(a) (a).begin(),(a).end()
  // Iterator loop over container `c`; relies on the GNU `__typeof` extension.
  #define FOREACH(i, c) for( __typeof( (c).begin() ) i = (c).begin(); i != (c).end(); ++i )
  // x1(y2-y3) + x2(y3-y1) + x3(y1-y2): twice the signed area of the triangle
  // (x1,y1),(x2,y2),(x3,y3).
  #define AREA2(x1,y1,x2,y2,x3,y3) ( (x1)*((y2)-(y3)) + (x2)*((y3)-(y1)) + (x3)*((y1)-(y2)) )
  #define SQR(x) ((x)*(x))
  #define STR string
  #define IT iterator
  #define ff first
  #define ss second
  #define MP make_pair
  #define EPS 1e-9
  #define INF 1000000007

  // Bit helpers on an int-sized mask: test, clear and set bit k.
  #define chk(a,k) ((bool)((a)&(1<<(k))))
  #define set0(a,k) ((a)&(~(1<<(k))))
  #define set1(a,k) ((a)|(1<<(k)))
  40.  
  // Short aliases used throughout the file.
  // Lower-case 'l' variants are built on `long`, upper-case 'L' variants on
  // the 64-bit `Long`.
  typedef long long Long;

  typedef std::pair<long, long> Pll;
  typedef std::pair<Long, Long> PLL;

  typedef std::vector<long>   Vl;
  typedef std::vector<Long>   VL;
  typedef std::vector<double> VD;
  47.  
  // Returns the larger of x and y.
  // A plain comparison replaces the former shift-by-31 bit trick, which was only
  // valid for a 32-bit `long` and could overflow while computing y-x.
  inline long FastMax(long x, long y) { return ( x < y ) ? y : x; }
  // Returns the smaller of x and y.
  // A plain comparison replaces the former shift-by-31 bit trick, which was only
  // valid for a 32-bit `long` and could overflow while computing y-x.
  inline long FastMin(long x, long y) { return ( y < x ) ? y : x; }
  50.  
  // Paired step tables: (IR[k], IC[k]) is one step. The first four pairs are
  // the axis-aligned unit steps, the last four the diagonal ones.
  // Neither table is referenced by the code visible in this file.
  long IR[] = {  0, -1,  0,  1,
                -1, -1,  1,  1 };
  long IC[] = {  1,  0, -1,  0,
                 1, -1, -1,  1 };

  // Largest attribute value the tree code provides a bucket/child slot for;
  // per-value tables are sized MAX_VAL+1.
  #define MAX_VAL 10
  55.  
  // One node of the two-class decision tree.
  // A node is either a leaf (no children, carries a class) or an internal node
  // (tests one attribute, one child slot per attribute value).
  // NODE has no destructor: it never frees the nodes its child pointers refer to.
  struct NODE{
      long cls_type;               // class predicted by a leaf; -1 on internal nodes
      long attr_type;              // attribute index tested by an internal node; -1 on leaves
      std::vector<NODE*> child;    // child[v] = subtree for attribute value v; empty on leaves

      // Leaf constructor: stores the class and leaves `child` empty.
      NODE( long cls_type ):cls_type( cls_type ),attr_type( -1 ){}

      // Internal-node constructor: reserves slots 0..tot_value, all initially null.
      NODE( long attr_type, long tot_value ):cls_type( -1 ),attr_type( attr_type )
      {
          child.resize( tot_value+1 );
      }

      // A node is a leaf exactly when it has no child slots.
      bool is_leaf( void ) const { return child.empty(); }
  };
  67.  
  // Every row parsed from the input file, one vector of numbers per line.
  // Filled by read_data() and reshuffled between evaluation rounds in main().
  std::vector< std::vector<long> > input_data;
  69.  
  // Counts the rows of `data` per class.
  // The label is read from column 9: a label of 0 counts towards .first, any
  // other label towards .second. Rows too short to have a label column are
  // skipped instead of being read out of bounds.
  // `data` is taken by const reference so the data set is not copied per call.
  std::pair<long,long> calc_amount( const std::vector< std::vector<long> > &data )
  {
      const std::size_t label_col = 9;
      std::pair<long,long> cnt( 0,0 );
      for( std::size_t row=0; row<data.size(); ++row ){
          if( data[row].size() <= label_col ) continue;   // e.g. a blank input line
          if( data[row][label_col]==0 ) ++cnt.first;
          else ++cnt.second;
      }
      return cnt;
  }
  80.  
  81. double calc_entropy( vector<Vl> data )
  82. {
  83.     Pll cnt = calc_amount( data );
  84.     double p0 = (double)cnt.ff/( cnt.ff+cnt.ss );
  85.     double p1 = (double)cnt.ss/( cnt.ff+cnt.ss );
  86.     return -p0*log( p0 ) - p1*log( p1 );
  87. }
  88.  
  89. double calc_info_gain( vector<Vl> data, long a )
  90. {
  91.     double en = calc_entropy( data );
  92.     vector<Vl> part[MAX_VAL+1];
  93.     long i;
  94.     for( i=0;i<data.size();i++ ) part[data[i][a]].pb( data[i] );
  95.     double ig = en;
  96.     for( i=1;i<=MAX_VAL;i++ ){
  97.         ig -= part[i].size() * calc_entropy( part[i] ) / data.size();
  98.     }
  99.     return ig;
  100. }
  101.  
  102. NODE* build_ID3( vector<Vl> data, Vl attr )
  103. {
  104.     Pll cnt = calc_amount( data );
  105.     if( cnt.ff==data.size() ) return new NODE( 0 );
  106.     if( cnt.ss==data.size() ) return new NODE( 1 );
  107.     if( !attr.size() ) return new NODE( ( cnt.ff >= cnt.ss ) ? 0:1 );
  108.  
  109.     long i,w = -1;
  110.     VD info_gain;
  111.     for( i=0;i<attr.size();i++ ) info_gain.pb( calc_info_gain( data,attr[i] ) );
  112.     for( i=0;i<attr.size();i++ ) w = ( w==-1 or info_gain[w] < info_gain[i] ) ? i:w;
  113.     w = attr[w];
  114.     attr.erase( find( ALL( attr ),w ) );
  115.  
  116.  
  117.     vector<Vl> part[MAX_VAL+1];
  118.     for( i=0;i<data.size();i++ ) part[data[i][w]].pb( data[i] );
  119.  
  120.     NODE *cur_node = new NODE( w,MAX_VAL );
  121.     for( i=1;i<=10;i++ ){
  122.         if( part[i].size() ) cur_node->child[i] = build_ID3( part[i],attr );
  123.         else cur_node->child[i] = new NODE( (cnt.ff >= cnt.ss ) ? 0:1 );
  124.     }
  125.     return cur_node;
  126. }
  127.  
  128. long find_class( NODE *r, Vl attr_val )
  129. {
  130.     if( r->is_leaf()) return r->cls_type;
  131.     else{
  132.         long a = r->attr_type;
  133.         return find_class( r->child[attr_val[a]], attr_val );
  134.     }
  135. }
  136.  
  137.  
  138. void read_data( void )
  139. {
  140.     char str[107];
  141.     freopen("data.csv","r",stdin );
  142.     while( gets( str ) ){
  143.         char *p = strtok( str,", " );
  144.         vector<long> v;
  145.         while( p ){
  146.             v.pb( atol( p ) );
  147.             p = strtok( NULL,", " );
  148.         }
  149.         input_data.pb( v );
  150.     }
  151. }
  152.  
  // Splits `v` in order: the first floor(percentage * v.size()) rows go to `v1`,
  // the remaining rows to `v2`. Both outputs are overwritten.
  // The cut is clamped to [0, v.size()], so a percentage above 1 no longer reads
  // past the end of `v`; a negative or NaN percentage puts every row in `v2`.
  // `v` is deliberately taken by value so the call stays correct even when an
  // output vector aliases the input.
  void partition_data( std::vector< std::vector<long> > v, std::vector< std::vector<long> > &v1, std::vector< std::vector<long> > &v2, double percentage )
  {
      const std::size_t total = v.size();
      const double want = percentage*total;

      // Clamp in floating point first: converting an out-of-range double to an
      // integer type is undefined.
      std::size_t cut = 0;
      if( want >= (double)total ) cut = total;
      else if( want > 0 ) cut = (std::size_t)want;

      v1.assign( v.begin(), v.begin()+cut );
      v2.assign( v.begin()+cut, v.end() );
  }
  161.  
  // num/den, or 0 when the denominator is not positive (avoids NaN/inf metrics).
  static double safe_ratio( double num, double den )
  {
      return ( den > 0 ) ? num/den : 0.0;
  }

  // Derives evaluation metrics from a 2x2 table indexed [predicted][actual],
  // with class 0 treated as the positive class:
  //   result[0][0]=TP  result[0][1]=FP  result[1][0]=FN  result[1][1]=TN
  // Returns, in this order, all as fractions in [0,1]:
  //   [0] accuracy   (TP+TN)/all
  //   [1] precision  TP/(TP+FP)
  //   [2] recall     TP/(TP+FN)
  //   [3] F-measure  2*P*R/(P+R)
  //   [4] G-mean     sqrt( TP/(TP+FN) * TN/(TN+FP) )
  // A ratio with a zero denominator is reported as 0 rather than NaN, so one
  // degenerate fold cannot poison an average taken over several folds.
  // The G-mean used to be 100*sqrt(TP*TN)/((TP+FN)*(TN+FP)), i.e. the square root
  // was missing from the denominator; it is now the geometric mean of recall
  // and specificity.
  std::vector<double> analysis( long result[2][2] )
  {
      const double tp = result[0][0];
      const double fp = result[0][1];
      const double fn = result[1][0];
      const double tn = result[1][1];

      std::vector<double> v;
      v.push_back( safe_ratio( tp+tn, tp+fp+fn+tn ) );
      v.push_back( safe_ratio( tp, tp+fp ) );
      v.push_back( safe_ratio( tp, tp+fn ) );
      v.push_back( safe_ratio( 2.0*v[1]*v[2], v[1]+v[2] ) );
      v.push_back( sqrt( safe_ratio( tp, tp+fn )*safe_ratio( tn, tn+fp ) ) );
      return v;
  }
  176.  
  // Writes the five metrics held in anal_info[0..4] to stdout, one labelled line
  // each with four decimals, followed by a blank line.
  void print( std::vector<double> anal_info )
  {
      // One printf format per metric, in the order the metrics are stored.
      static const char *const line_format[] = {
          "Accuracy %.4lf\n",
          "Precision %.4lf\n",
          "Recall %.4lf\n",
          "F-measure %.4lf\n",
          "G mean %.4lf\n"
      };
      const int metric_count = sizeof( line_format )/sizeof( line_format[0] );

      for( int m=0; m<metric_count; ++m ){
          printf( line_format[m],anal_info[m] );
      }
      printf("\n");
  }
  186.  
  187. int main( void )
  188. {
  189.     long i,j,Icase,k = 0;
  190.  
  191.     //freopen("text1.txt","r",stdin );
  192.  
  193.     read_data();
  194.     VD anal_info;
  195.     for( i=1;i<=10;i++ ){
  196.         vector<Vl> train_data,test_data;
  197.         long result[2][2] = {0};
  198.         partition_data( input_data, train_data, test_data, 0.8 );
  199.         vector<long> attr;
  200.         for( j=0;j<9;j++ ) attr.pb( j );
  201.         NODE *root = build_ID3( train_data, attr );
  202.         for( j=0;j<test_data.size();j++ ){
  203.             long t = find_class( root, test_data[j] );
  204.             result[t][test_data[j][9]]++;
  205.         }
  206.         VD v = analysis( result );
  207.         if( i==1 ){
  208.             anal_info = v;
  209.             // print( anal_info );
  210.         }
  211.         else{
  212.             for( j=0;j<anal_info.size();j++ ){
  213.                 anal_info[j] += v[j];
  214.             }
  215.         }
  216.         random_shuffle( ALL( input_data ) );
  217.     }
  218.     for( j=0;j<anal_info.size();j++ ){
  219.         anal_info[j] /= 10;
  220.     }
  221.     print( anal_info );
  222.  
  223.     return 0;
  224. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement