Advertisement
Guest User

mlpackRNNTest

a guest
Mar 29th, 2016
102
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 1.47 KB | None | 0 0
  1. int main(int argc, char** argv) {
  2.    
  3.     arma::mat data, labels, prediction;
  4.     size_t HiddenLayerSize, InputLayerSize, error;
  5.     double ClassificationError;
  6.        
  7.     data.load("snpdata.csv");
  8.     labels = data.col(0);
  9.     data.shed_col(0);
  10.    
  11.     InputLayerSize = data.n_rows;
  12.     HiddenLayerSize = data.n_rows + (data.n_rows / 2);
  13.        
  14.     LinearLayer<> linearLayer0(InputLayerSize, HiddenLayerSize);
  15.     RecurrentLayer<> recurrentLayer0(HiddenLayerSize);
  16.     BaseLayer<LogisticFunction> inputBaseLayer;
  17.    
  18.     LinearLayer<> hiddenLayer(HiddenLayerSize, 1);
  19.     BaseLayer<LogisticFunction> hiddenBaseLayer;
  20.    
  21.     BinaryClassificationLayer classOutputLayer;
  22.    
  23.     auto modules = std::tie(linearLayer0, recurrentLayer0, inputBaseLayer,
  24.                             hiddenLayer, hiddenBaseLayer);
  25.    
  26.     RNN<decltype(modules), BinaryClassificationLayer, RandomInitialization,
  27.             MeanSquaredErrorFunction> NN(modules, classOutputLayer);
  28.    
  29.     SGD<decltype(NN)> opt(NN, 0.01, 100000, 1e-5, true);
  30.    
  31.     NN.Train(data, labels, opt);
  32.     /*
  33.     NN.Predict(data, prediction);
  34.    
  35.     error = 0;
  36.     for (size_t i = 0; i < labels.n_rows; i++)
  37.     {
  38.         if (arma::sum(arma::sum(arma::abs(prediction.col(i) - labels.col(i)))) == 0)
  39.             error++;
  40.     }
  41.    
  42.     ClassificationError = error / (double) labels.n_rows * 100;
  43.    
  44.     cout << "Classification Error: " << ClassificationError << '\n';*/
  45.    
  46.     return 0;
  47. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement