Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#include <cstddef>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/rnn.hpp>
#include <mlpack/methods/ann/loss_functions/mean_squared_error.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
- /*
- Loads the data (as a CSV file) and reshapes it into a usable state.
- */
- arma::cube LoadData(std::string file){
- /*
- row = feature = input size = 4 features
- col = datapoints = rho = 16 datapoints per slice (timestep)
- slice = timestep = 2 slices (timesteps)
- */
- arma::mat filedata;
- mlpack::data::Load(file, filedata, true);
- long long unsigned int amount_slices = (filedata.n_cols - 16 + 1);
- arma::cube temp(4,16,amount_slices);
- for(unsigned int i = 0; i < amount_slices; ++i){
- temp.slice(i) = filedata.submat(0, i, 3, i+15);
- }
- return temp;
- }
- int main(){
- arma::cube input = LoadData("data.csv"); //print this cube if you want, there should be no problem with it
- mlpack::ann::RNN<mlpack::ann::MeanSquaredError<> > model(2);
- model.Add<mlpack::ann::Linear<> >(4, 10);
- model.Add<mlpack::ann::GRU<> >(10, 10);
- model.Add<mlpack::ann::Linear<> >(10, 3);
- model.Parameters().randu();
- arma::cube predictions(4,16,2), predictors(3,16,2);
- model.Train(input, predictors); //here's where it errors
- model.Predict(input, predictions); //here's where it errors
- return 0;
- }
Add Comment
Please, Sign In to add comment