Guest User

Neural_network_using_mlpack.cpp

a guest
Apr 8th, 2016
45
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.38 KB | None | 0 0
  1. /*
  2. Compile with
  3. g++ -std=c++11 Neural_network_using_mlpack.cpp -l mlpack -l armadillo -l boost_serialization -l boost_program_options
  4. where Neural_network_using_mlpack.cpp is the name of this cpp file
  5. */
  6.  
  7. #include <mlpack/core.hpp>
  8.  
  9. #include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
  10. #include <mlpack/methods/ann/activation_functions/tanh_function.hpp>
  11.  
  12. #include <mlpack/methods/ann/init_rules/random_init.hpp>
  13.  
  14. #include <mlpack/methods/ann/layer/bias_layer.hpp>
  15. #include <mlpack/methods/ann/layer/linear_layer.hpp>
  16. #include <mlpack/methods/ann/layer/base_layer.hpp>
  17. #include <mlpack/methods/ann/layer/dropout_layer.hpp>
  18. #include <mlpack/methods/ann/layer/binary_classification_layer.hpp>
  19.  
  20. #include <mlpack/methods/ann/ffn.hpp>
  21. #include <mlpack/methods/ann/performance_functions/mse_function.hpp>
  22. #include <mlpack/core/optimizers/rmsprop/rmsprop.hpp>
  23.  
  24. using namespace mlpack;
  25. using namespace mlpack::ann;
  26. using namespace mlpack::optimization;
  27.  
  28. template<
  29.     typename PerformanceFunction,
  30.     typename OutputLayerType,
  31.     typename PerformanceFunctionType,
  32.     typename MatType = arma::mat
  33. >
  34. void BuildNetwork(MatType& trainData,
  35.                          MatType& trainLabels,
  36.                          MatType& testData,
  37.                          MatType& testLabels,
  38.                          const size_t hiddenLayerSize,
  39.                          const size_t maxEpochs,
  40.                          const double classificationErrorThreshold)
  41. {
  42.   LinearLayer<> inputLayer(trainData.n_rows, hiddenLayerSize);
  43.   BiasLayer<> inputBiasLayer(hiddenLayerSize);
  44.   BaseLayer<PerformanceFunction> inputBaseLayer;
  45.  
  46.   LinearLayer<> hiddenLayer1(hiddenLayerSize, trainLabels.n_rows);
  47.   BiasLayer<> hiddenBiasLayer1(trainLabels.n_rows);
  48.   BaseLayer<PerformanceFunction> outputLayer;
  49.  
  50.   OutputLayerType classOutputLayer;
  51.  
  52.   auto modules = std::tie(inputLayer, inputBiasLayer, inputBaseLayer,
  53.                           hiddenLayer1, hiddenBiasLayer1, outputLayer);
  54.  
  55.   FFN<decltype(modules), decltype(classOutputLayer), RandomInitialization,
  56.       PerformanceFunctionType> net(modules, classOutputLayer);
  57.  
  58.   RMSprop<decltype(net)> opt(net, 0.01, 0.88, 1e-8,
  59.       maxEpochs * trainData.n_cols, 1e-18);
  60.  
  61.  std::cout<<"Success"<<std::endl;
  62.  
  63.   net.Train(trainData, trainLabels, opt);
  64.  
  65.   MatType prediction;
  66.   net.Predict(testData, prediction);
  67.  
  68.   size_t error = 0;
  69.   for (size_t i = 0; i < testData.n_cols; i++)
  70.   {
  71.     if (arma::sum(arma::sum(
  72.         arma::abs(prediction.col(i) - testLabels.col(i)))) == 0)
  73.     {
  74.       error++;
  75.     }
  76.   }
  77.  
  78.   double classificationError = 1 - double(error) / testData.n_cols;
  79.  
  80. }
  81.  
  82. int main()
  83. {
  84.   arma::mat dataset;
  85.   data::Load("thyroid_train.csv", dataset, true);
  86.  
  87.   arma::mat trainData = dataset.submat(0, 0, dataset.n_rows - 4,
  88.       dataset.n_cols - 1);
  89.   arma::mat trainLabels = dataset.submat(dataset.n_rows - 3, 0,
  90.       dataset.n_rows - 1, dataset.n_cols - 1);
  91.  
  92.   data::Load("thyroid_test.csv", dataset, true);
  93.  
  94.   arma::mat testData = dataset.submat(0, 0, dataset.n_rows - 4,
  95.       dataset.n_cols - 1);
  96.   arma::mat testLabels = dataset.submat(dataset.n_rows - 3, 0,
  97.       dataset.n_rows - 1, dataset.n_cols - 1);
  98.  
  99.   BuildNetwork<LogisticFunction,
  100.                       BinaryClassificationLayer,
  101.                       MeanSquaredErrorFunction>
  102.       (trainData, trainLabels, testData, testLabels, 8, 200, 0.1);
  103. return 0;
  104. }
Add Comment
Please sign in to add a comment.