Guest User

Untitled

a guest
Jun 22nd, 2017
65
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.35 KB | None | 0 0
#include <iostream>
#include <stdexcept>
#include <string>

#include <mlpack/core.hpp>
#include <mlpack/core/optimizers/rmsprop/rmsprop.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/ffn.hpp>

using namespace mlpack;
using namespace mlpack::ann;
using namespace mlpack::optimization;
  9.  
  10. template<typename MatType = arma::mat>
  11. void BuildVanillaNetwork(MatType& trainData,
  12. MatType& trainLabels,
  13. MatType& testData,
  14. MatType& testLabels,
  15. const size_t outputSize,
  16. const size_t hiddenLayerSize,
  17. const size_t maxEpochs,
  18. const double classificationErrorThreshold)
  19. {
  20. /*
  21. * Construct a feed forward network with trainData.n_rows input nodes,
  22. * hiddenLayerSize hidden nodes and trainLabels.n_rows output nodes. The
  23. * network structure looks like:
  24. *
  25. * Input Hidden Output
  26. * Layer Layer Layer
  27. * +-----+ +-----+ +-----+
  28. * | | | | | |
  29. * | +------>| +------>| |
  30. * | | +>| | +>| |
  31. * +-----+ | +--+--+ | +-----+
  32. * | |
  33. * Bias | Bias |
  34. * Layer | Layer |
  35. * +-----+ | +-----+ |
  36. * | | | | | |
  37. * | +-----+ | +-----+
  38. * | | | |
  39. * +-----+ +-----+
  40. */
  41.  
  42. FFN<NegativeLogLikelihood<> > model;
  43. model.Add<Linear<> >(trainData.n_rows, hiddenLayerSize);
  44. model.Add<SigmoidLayer<> >();
  45. model.Add<Linear<> >(hiddenLayerSize, outputSize);
  46. model.Add<LogSoftMax<> >();
  47.  
  48. RMSProp<decltype(model)> opt(model, 0.01, 0.88, 1e-8,
  49. maxEpochs * trainData.n_cols, -1);
  50.  
  51. model.Train(trainData, trainLabels, opt);
  52.  
  53. MatType predictionTemp;
  54. model.Predict(testData, predictionTemp);
  55. MatType prediction = arma::zeros<MatType>(1, predictionTemp.n_cols);
  56.  
  57. for (size_t i = 0; i < predictionTemp.n_cols; ++i)
  58. {
  59. prediction(i) = arma::as_scalar(arma::find(
  60. arma::max(predictionTemp.col(i)) == predictionTemp.col(i), 1)) + 1;
  61. }
  62.  
  63. size_t error = 0;
  64. for (size_t i = 0; i < testData.n_cols; i++)
  65. {
  66. if (int(arma::as_scalar(prediction.col(i))) ==
  67. int(arma::as_scalar(testLabels.col(i))))
  68. {
  69. error++;
  70. }
  71. }
  72.  
  73. double classificationError = 1 - double(error) / testData.n_cols;
  74. BOOST_REQUIRE_LE(classificationError, classificationErrorThreshold);
  75. }
  76.  
  77.  
  78. int main()
  79. {
  80. // Load the dataset.
  81. arma::mat dataset;
  82. data::Load("thyroid_train.csv", dataset, true);
  83.  
  84. arma::mat trainData = dataset.submat(0, 0, dataset.n_rows - 4,
  85. dataset.n_cols - 1);
  86.  
  87. arma::mat trainLabelsTemp = dataset.submat(dataset.n_rows - 3, 0,
  88. dataset.n_rows - 1, dataset.n_cols - 1);
  89. arma::mat trainLabels = arma::zeros<arma::mat>(1, trainLabelsTemp.n_cols);
  90. for (size_t i = 0; i < trainLabelsTemp.n_cols; ++i)
  91. {
  92. trainLabels(i) = arma::as_scalar(arma::find(
  93. arma::max(trainLabelsTemp.col(i)) == trainLabelsTemp.col(i), 1)) + 1;
  94. }
  95.  
  96. data::Load("thyroid_test.csv", dataset, true);
  97.  
  98. arma::mat testData = dataset.submat(0, 0, dataset.n_rows - 4,
  99. dataset.n_cols - 1);
  100.  
  101. arma::mat testLabelsTemp = dataset.submat(dataset.n_rows - 3, 0,
  102. dataset.n_rows - 1, dataset.n_cols - 1);
  103.  
  104. arma::mat testLabels = arma::zeros<arma::mat>(1, testLabelsTemp.n_cols);
  105. for (size_t i = 0; i < testLabels.n_cols; ++i)
  106. {
  107. testLabels(i) = arma::as_scalar(arma::find(
  108. arma::max(testLabelsTemp.col(i)) == testLabelsTemp.col(i), 1)) + 1;
  109. }
  110.  
  111. // Vanilla neural net with logistic activation function.
  112. // Because 92 percent of the patients are not hyperthyroid the neural
  113. // network must be significant better than 92%.
  114. BuildVanillaNetwork<>
  115. (trainData, trainLabels, testData, testLabels, 3, 8, 70, 0.1);
  116.  
  117. dataset.load("mnist_first250_training_4s_and_9s.arm");
  118.  
  119. // Normalize each point since these are images.
  120. for (size_t i = 0; i < dataset.n_cols; ++i)
  121. dataset.col(i) /= norm(dataset.col(i), 2);
  122.  
  123. arma::mat labels = arma::zeros(1, dataset.n_cols);
  124. labels.submat(0, labels.n_cols / 2, 0, labels.n_cols - 1).fill(1);
  125. labels += 1;
  126.  
  127. // Vanilla neural net with logistic activation function.
  128. BuildVanillaNetwork<>
  129. (dataset, labels, dataset, labels, 2, 10, 50, 0.2);
  130. }
Add Comment
Please, Sign In to add comment