Advertisement
Guest User

Untitled

a guest
Apr 8th, 2016
91
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 6.63 KB | None | 0 0
  1. /*
  2.  * To change this license header, choose License Headers in Project Properties.
  3.  * To change this template file, choose Tools | Templates
  4.  * and open the template in the editor.
  5.  */
  6.  
  7. /*
  8.  * File:   main.cpp
  9.  * Author: ozair
  10.  *
  11.  * Created on March 29, 2016, 1:33 PM
  12.  */
  13.  
#include <iostream>
#include <stdexcept>
using namespace std;
  16.  
  17. #include <mlpack/core.hpp>
  18.  
  19. #include <mlpack/methods/ann/ffn.hpp>
  20. #include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
  21. #include <mlpack/methods/ann/performance_functions/mse_function.hpp>
  22. #include <mlpack/core/optimizers/sgd/sgd.hpp>
  23.  
  24. #include <mlpack/methods/ann/init_rules/random_init.hpp>
  25.  
  26. #include <mlpack/methods/ann/layer/bias_layer.hpp>
  27. #include <mlpack/methods/ann/layer/linear_layer.hpp>
  28. #include <mlpack/methods/ann/layer/base_layer.hpp>
  29. #include <mlpack/methods/ann/layer/multiclass_classification_layer.hpp>
  30. #include <mlpack/methods/ann/layer/dropout_layer.hpp>
  31. #include <mlpack/methods/ann/layer/dropconnect_layer.hpp>
  32.  
  33. using namespace mlpack;
  34. using namespace mlpack::ann;
  35. using namespace mlpack::optimization;
  36.  
  37. struct ConfusionMatrix {
  38.     double accuracy;
  39.     double precision;
  40.     double recall;
  41. };
  42.  
  43. void printMatAttr(char const* MatName, const arma::mat &data) {
  44.     cout << "Attributes of " << MatName << '\n';
  45.     cout << "# of rows: " << data.n_rows << '\n';
  46.     cout << "# of cols: " << data.n_cols << '\n';
  47.     cout << "# of elements: " << data.n_elem << "\n\n";
  48. }
  49.  
  50. ConfusionMatrix getClassificationError(const arma::mat &prediction, const arma::mat &labels, int positiveClass, double confidenceThreshold) {
  51.     ConfusionMatrix attr;
  52.     size_t error = 0, tp = 0;
  53.     arma::mat predictionClass(prediction.n_rows, prediction.n_cols);
  54.     for (size_t i = 0; i < labels.n_cols; i++) {
  55.         predictionClass(0, i) = prediction(0, i) > confidenceThreshold ? 1 : 0;
  56.         if (arma::sum(arma::abs(predictionClass.col(i) - labels.col(i))) != 0)    
  57.             error++;
  58.         else if (arma::sum(predictionClass.col(i) == positiveClass) == 1 && arma::sum(labels.col(i) == positiveClass) == 1)
  59.             tp++;
  60.     }
  61.        
  62.     attr.accuracy    = ((labels.n_cols - error) / (double) labels.n_cols) * 100;
  63.     attr.precision   = tp / (double) (arma::sum(predictionClass.row(0) == positiveClass));
  64.     attr.recall      = tp / (double) (arma::sum(labels.row(0) == positiveClass));
  65.              
  66.     return attr;
  67. }
  68.  
  69. int main(int argc, char** argv) {    
  70.     arma::mat trainingData, trainingLabels, testData, testLabels, testprediction, trainprediction;
  71.     size_t HiddenLayerSize, InputLayerSize, NumOfSamples, NumOfTrainingSamples, NumOfTestSamples, NumOfFeatures;
  72.     double trainingRatio = 0.75;
  73.    
  74.     trainingData.load("snpbinary.csv");
  75.     NumOfSamples = trainingData.n_rows;
  76.     NumOfFeatures = trainingData.n_cols - 1;
  77.    
  78.     arma::arma_rng::set_seed_random();
  79.     trainingData = arma::shuffle(trainingData);
  80.    
  81.     cout << "Number of samples: " << NumOfSamples << "\n\n";
  82.    
  83.     trainingLabels = trainingData.col(NumOfFeatures);
  84.     trainingData.shed_col(NumOfFeatures);
  85.    
  86.     NumOfTrainingSamples = NumOfSamples * trainingRatio;
  87.     NumOfTestSamples = NumOfSamples - NumOfTrainingSamples;
  88.    
  89.     testData = trainingData.rows(0, NumOfTestSamples - 1);
  90.     testLabels = trainingLabels.rows(0, NumOfTestSamples - 1);
  91.    
  92.     trainingData.shed_rows(0, NumOfTestSamples - 1);
  93.     trainingLabels.shed_rows(0, NumOfTestSamples - 1);
  94.        
  95.     printMatAttr("TrainingData", trainingData);
  96.     printMatAttr("TrainingLabels", trainingLabels);
  97.     printMatAttr("TestData", testData);
  98.     printMatAttr("TestLabels", testLabels);
  99.    
  100.     trainingLabels = trans(trainingLabels);
  101.     trainingData = trans(trainingData);
  102.     testData = trans(testData);
  103.     testLabels = trans(testLabels);
  104.    
  105.     InputLayerSize = NumOfFeatures;
  106.     HiddenLayerSize = NumOfFeatures + (NumOfFeatures / 2);        
  107.    
  108.     LinearLayer<> inputLayer(NumOfFeatures, HiddenLayerSize);
  109.     BiasLayer<> inputBiasLayer(HiddenLayerSize);
  110.     BaseLayer<LogisticFunction> inputBaseLayer;
  111.    
  112.     LinearLayer<> hiddenLayer1(HiddenLayerSize, HiddenLayerSize);
  113.     BiasLayer<> hiddenBiasLayer1(HiddenLayerSize);
  114.     BaseLayer<LogisticFunction> hiddenBaseLayer1;
  115.    
  116.     LinearLayer<> hiddenLayer2(HiddenLayerSize, HiddenLayerSize);
  117.     BiasLayer<> hiddenBiasLayer2(HiddenLayerSize);
  118.     BaseLayer<LogisticFunction> hiddenBaseLayer2;
  119.    
  120.     LinearLayer<> hiddenLayer3(HiddenLayerSize, HiddenLayerSize);
  121.     BiasLayer<> hiddenBiasLayer3(HiddenLayerSize);
  122.     BaseLayer<LogisticFunction> hiddenBaseLayer3;
  123.    
  124.     LinearLayer<> hiddenLayer4(HiddenLayerSize, HiddenLayerSize);
  125.     BiasLayer<> hiddenBiasLayer4(HiddenLayerSize);
  126.     BaseLayer<LogisticFunction> hiddenBaseLayer4;
  127.    
  128.     LinearLayer<> outputLayer(HiddenLayerSize, 1);
  129.     BiasLayer<> outputBiasLayer(1);
  130.     //DropConnectLayer<decltype(outputLayer)> dropConnectLayer0(outputLayer);
  131.     BaseLayer<LogisticFunction> outputBaseLayer;
  132.    
  133.     DropoutLayer<> dropoutLayer0;
  134.  
  135.     MulticlassClassificationLayer classOutputLayer;
  136.    
  137.     auto modules = std::tie(inputLayer, inputBiasLayer, inputBaseLayer,
  138.                             //hiddenLayer1, hiddenBiasLayer1, hiddenBaseLayer1,
  139.                             //hiddenLayer2, hiddenBiasLayer2, hiddenBaseLayer2,
  140.                             //hiddenLayer3, hiddenBiasLayer3, hiddenBaseLayer3,
  141.                             hiddenLayer4, hiddenBiasLayer4, hiddenBaseLayer4,
  142.                             outputLayer, outputBiasLayer, outputBaseLayer);
  143.    
  144.     FFN<decltype(modules), decltype(classOutputLayer), RandomInitialization,
  145.       MeanSquaredErrorFunction> NN(modules, classOutputLayer);
  146.    
  147.     SGD<decltype(NN)> opt(NN, 0.01, 100000, 1e-5, true);
  148.    
  149.     cout << "Begin Training\n";
  150.     NN.Train(trainingData, trainingLabels, opt);
  151.    
  152.     NN.Predict(trainingData, trainprediction);
  153.     NN.Predict(testData, testprediction);
  154.    
  155.     arma::Col<double> confidenceList = arma::linspace<arma::Col<double>>(0.05, 0.95, 36);
  156.     ConfusionMatrix attr;
  157.    
  158.     for (size_t t = 0; t < confidenceList.n_elem; t++) {
  159.         cout << confidenceList(t) << ", ";      
  160.         attr = getClassificationError(trainprediction, trainingLabels, 1, confidenceList(t));
  161.         cout << attr.accuracy << ", ";
  162.         cout << attr.precision << ", ";
  163.         cout << attr.recall << " ... ";
  164.        
  165.         cout << confidenceList(t) << ", ";
  166.         attr = getClassificationError(testprediction, testLabels, 1, confidenceList(t));
  167.         cout << attr.accuracy << ", ";
  168.         cout << attr.precision << ", ";
  169.         cout << attr.recall << "\n";
  170.     }
  171.      
  172.     return 0;
  173. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement