// some common parameters
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
builder.seed(123);
builder.biasInit(0);
builder.miniBatch(false);
builder.updater(new RmsProp(0.001));
builder.weightInit(WeightInit.XAVIER);

ListBuilder listBuilder = builder.list();

// first difference: for RNNs we need to use GravesLSTM.Builder
for (int i = 0; i < HIDDEN_LAYER_CONT; i++) {
    GravesLSTM.Builder hiddenLayerBuilder = new GravesLSTM.Builder();
    hiddenLayerBuilder.nIn(i == 0 ? LEARNSTRING_CHARS.size() : HIDDEN_LAYER_WIDTH);
    hiddenLayerBuilder.nOut(HIDDEN_LAYER_WIDTH);
    // activation function adopted from GravesLSTMCharModellingExample;
    // seems to work well with RNNs
    hiddenLayerBuilder.activation(Activation.TANH);
    listBuilder.layer(i, hiddenLayerBuilder.build());
}

// we need to use RnnOutputLayer for our RNN
RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.MCXENT);
// softmax normalizes the output neurons so that the sum of all outputs is 1;
// this is required for our sampleFromDistribution function
outputLayerBuilder.activation(Activation.SOFTMAX);
outputLayerBuilder.nIn(HIDDEN_LAYER_WIDTH);
outputLayerBuilder.nOut(LEARNSTRING_CHARS.size());
listBuilder.layer(HIDDEN_LAYER_CONT, outputLayerBuilder.build());

// finish the builder
listBuilder.pretrain(false);
listBuilder.backprop(true);

// create the network
MultiLayerConfiguration conf = listBuilder.build();
net = new MultiLayerNetwork(conf);

net.init();
//net.setListeners(new ScoreIterationListener(1));

///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// some epochs
for (int epoch = 0; epoch < 1000; epoch++) {

    System.out.println("Epoch " + epoch);

    provideUIServer();

    // train on the data
    net.fit(trainingData);
    System.out.println("batch " + net.batchSize());

    // clear the current state left over from the last example
    net.rnnClearPreviousState();
}
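
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// The paste ends right after rnnClearPreviousState(). A common way to check progress each epoch is to feed one
// one-hot character into the network and sample the following characters with rnnTimeStep(). The block below is only
// an assumption about how that could look: LEARNSTRING_CHARS and net come from the surrounding class,
// sampleFromDistribution is sketched further below, and the Random seed, the start index 0 and the sample length
// of 50 are placeholders for illustration.
// Requires: org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.factory.Nd4j, java.util.Random.
Random rng = new Random(123);

// one-hot input vector for the assumed first character of the training string
INDArray init = Nd4j.zeros(1, LEARNSTRING_CHARS.size());
init.putScalar(new int[] { 0, 0 }, 1.0);
INDArray output = net.rnnTimeStep(init);

for (int step = 0; step < 50; step++) {
    // copy the softmax row into a plain double[] for the sampler
    double[] distribution = new double[LEARNSTRING_CHARS.size()];
    for (int j = 0; j < distribution.length; j++) {
        distribution[j] = output.getDouble(j);
    }
    int sampledIdx = sampleFromDistribution(distribution, rng);
    // ... map sampledIdx back to its character and print it ...

    // feed the sampled character back in as the next input
    INDArray next = Nd4j.zeros(1, LEARNSTRING_CHARS.size());
    next.putScalar(new int[] { 0, sampledIdx }, 1.0);
    output = net.rnnTimeStep(next);
}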
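
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// The output-layer comments above mention a sampleFromDistribution function that is not included in this paste.
// The helper below is only a sketch of what it could look like, assuming it receives the softmax output row as an
// already normalized double[] probability vector and a java.util.Random instance; the author's actual signature is
// not shown here.
private static int sampleFromDistribution(double[] distribution, Random rng) {
    double p = rng.nextDouble();
    double cumulative = 0.0;
    for (int i = 0; i < distribution.length; i++) {
        cumulative += distribution[i];
        if (p <= cumulative) {
            return i;
        }
    }
    // guard against floating-point rounding leaving a tiny remainder
    return distribution.length - 1;
}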
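
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// provideUIServer() is also not part of this paste. Assuming it wires the network up to the DL4J training UI, it
// could look roughly like the sketch below; attaching the UI once before the epoch loop would normally be enough.
// Requires the deeplearning4j-ui dependency: org.deeplearning4j.ui.api.UIServer, org.deeplearning4j.ui.stats.StatsListener,
// org.deeplearning4j.ui.storage.InMemoryStatsStorage, org.deeplearning4j.api.storage.StatsStorage.
private void provideUIServer() {
    UIServer uiServer = UIServer.getInstance();
    StatsStorage statsStorage = new InMemoryStatsStorage();
    uiServer.attach(statsStorage);
    net.setListeners(new StatsListener(statsStorage));
}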