Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
- const size_t rho = 10;
- // Generate 12 (2 * 6) noisy sines. A single sine contains rho
- // points/features.
- arma::cube input;
- arma::mat labelsTemp;
- GenerateNoisySines(input, labelsTemp, rho, 6);
- arma::cube labels = arma::zeros<arma::cube>(1, labelsTemp.n_cols, rho);
- for (size_t i = 0; i < labelsTemp.n_cols; ++i)
- {
- const int value = arma::as_scalar(arma::find(
- arma::max(labelsTemp.col(i)) == labelsTemp.col(i), 1)) + 1;
- labels.tube(0, i).fill(value);
- }
- /**
- * Construct a network with 1 input unit, 4 hidden units and 10 output
- * units. The hidden layer is connected to itself. The network structure
- * looks like:
- *
- * Input Hidden Output
- * Layer(1) Layer(4) Layer(10)
- * +-----+ +-----+ +-----+
- * | | | | | |
- * | +------>| +------>| |
- * | | ..>| | | |
- * +-----+ . +--+--+ +-----+
- * . .
- * . .
- * .......
- */
- Add<> add(4);
- Linear<> lookup(1, 4);
- SigmoidLayer<> sigmoidLayer;
- Linear<> linear(4, 4);
- Recurrent<>* recurrent = new Recurrent<>(add, lookup, linear,
- sigmoidLayer, rho);
- RNN<> model(rho);
- model.Add<IdentityLayer<> >();
- model.Add(recurrent);
- model.Add<Linear<> >(4, 10);
- model.Add<LogSoftMax<> >();
- StandardSGD opt(0.1, 1, input.n_cols /* 1 epoch */, -100);
- model.Train(input, labels, opt);
- // Serialize the network.
- RNN<> xmlModel(1), textModel(3), binaryModel(5);
- // Take predictions, check the output.
- arma::cube prediction, xmlPrediction, textPrediction, binaryPrediction;
- model.Predict(input, prediction);
- SerializeObjectAll(model, xmlModel, textModel, binaryModel);
- xmlModel.Predict(input, xmlPrediction);
- textModel.Predict(input, textPrediction);
- binaryModel.Predict(input, binaryPrediction);
Add Comment
Please sign in to add a comment.