mulx10

Untitled

Apr 15th, 2019
62
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 1.90 KB | None | 0 0
  1.   const size_t rho = 10;
  2.  
  3.   // Generate 12 (2 * 6) noisy sines. A single sine contains rho
  4.   // points/features.
  5.   arma::cube input;
  6.   arma::mat labelsTemp;
  7.   GenerateNoisySines(input, labelsTemp, rho, 6);
  8.  
  9.   arma::cube labels = arma::zeros<arma::cube>(1, labelsTemp.n_cols, rho);
  10.   for (size_t i = 0; i < labelsTemp.n_cols; ++i)
  11.   {
  12.     const int value = arma::as_scalar(arma::find(
  13.         arma::max(labelsTemp.col(i)) == labelsTemp.col(i), 1)) + 1;
  14.     labels.tube(0, i).fill(value);
  15.   }
  16.  
  17.   /**
  18.    * Construct a network with 1 input unit, 4 hidden units and 10 output
  19.    * units. The hidden layer is connected to itself. The network structure
  20.    * looks like:
  21.    *
  22.    *  Input         Hidden        Output
  23.    * Layer(1)      Layer(4)      Layer(10)
  24.    * +-----+       +-----+       +-----+
  25.    * |     |       |     |       |     |
  26.    * |     +------>|     +------>|     |
  27.    * |     |    ..>|     |       |     |
  28.    * +-----+    .  +--+--+       +-----+
  29.    *            .     .
  30.    *            .     .
  31.    *            .......
  32.    */
  33.   Add<> add(4);
  34.   Linear<> lookup(1, 4);
  35.   SigmoidLayer<> sigmoidLayer;
  36.   Linear<> linear(4, 4);
  37.   Recurrent<>* recurrent = new Recurrent<>(add, lookup, linear,
  38.       sigmoidLayer, rho);
  39.  
  40.   RNN<> model(rho);
  41.   model.Add<IdentityLayer<> >();
  42.   model.Add(recurrent);
  43.   model.Add<Linear<> >(4, 10);
  44.   model.Add<LogSoftMax<> >();
  45.  
  46.   StandardSGD opt(0.1, 1, input.n_cols /* 1 epoch */, -100);
  47.   model.Train(input, labels, opt);
  48.  
  49.   // Serialize the network.
  50.   RNN<> xmlModel(1), textModel(3), binaryModel(5);
  51.  
  52.   // Take predictions, check the output.
  53.   arma::cube prediction, xmlPrediction, textPrediction, binaryPrediction;
  54.   model.Predict(input, prediction);
  55.   SerializeObjectAll(model, xmlModel, textModel, binaryModel);
  56.   xmlModel.Predict(input, xmlPrediction);
  57.   textModel.Predict(input, textPrediction);
  58.   binaryModel.Predict(input, binaryPrediction);
Add Comment
Please, Sign In to add comment