mulx10

Untitled

Apr 5th, 2019
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 1.38 KB | None | 0 0
  #include <fstream>
  #include <iostream>
  #include <stdexcept>
  #include <string>

  #include <mlpack/core.hpp>
  #include <mlpack/methods/ann/rnn.hpp>
  #include <mlpack/methods/ann/loss_functions/mean_squared_error.hpp>
  #include <mlpack/methods/ann/layer/layer.hpp>
  8.  
  9. /*
  10. Loads the data (as a CSV file) and reshapes it into a usable state.
  11. */
  12. arma::cube LoadData(std::string file){
  13.   /*
  14.     row = feature = input size = 4 features
  15.     col = datapoints = rho     = 16 datapoints per slice (timestep)
  16.     slice = timestep           = 2 slices (timesteps)
  17.    */
  18.   arma::mat filedata;
  19.   mlpack::data::Load(file, filedata, true);
  20.  
  21.   long long unsigned int amount_slices = (filedata.n_cols - 16 + 1);
  22.  
  23.   arma::cube temp(4,16,amount_slices);
  24.   for(unsigned int i = 0; i < amount_slices; ++i){
  25.     temp.slice(i) = filedata.submat(0, i, 3, i+15);
  26.   }
  27.  
  28.   return temp;
  29. }
  30.  
  31.  
  32. int main(){
  33.   arma::cube input = LoadData("data.csv"); //print this cube if you want, there should be no problem with it
  34.  
  35.   mlpack::ann::RNN<mlpack::ann::MeanSquaredError<> > model(2);  
  36.  
  37.   model.Add<mlpack::ann::Linear<> >(4, 10);
  38.   model.Add<mlpack::ann::GRU<> >(10, 10);
  39.   model.Add<mlpack::ann::Linear<> >(10, 3);
  40.  
  41.   model.Parameters().randu();
  42.  
  43.   arma::cube predictions(4,16,2), predictors(3,16,2);
  44.   model.Train(input, predictors); //here's where it errors
  45.   model.Predict(input, predictions); //here's where it errors
  46.  
  47.   return 0;
  48. }
Add Comment
Please sign in to add a comment.