Guest User

Untitled

a guest
Nov 23rd, 2017
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.11 KB | None | 0 0
  1. package ai.skymind.training.labs;
  2.  
  3. import org.apache.log4j.BasicConfigurator;
  4. import org.deeplearning4j.api.storage.StatsStorage;
  5. import org.deeplearning4j.nn.api.OptimizationAlgorithm;
  6. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  7. import org.deeplearning4j.nn.conf.Updater;
  8. import org.deeplearning4j.nn.conf.layers.DenseLayer;
  9. import org.deeplearning4j.nn.conf.layers.OutputLayer;
  10. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  11. import org.deeplearning4j.nn.weights.WeightInit;
  12. import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
  13. import org.deeplearning4j.ui.api.UIServer;
  14. import org.deeplearning4j.ui.stats.StatsListener;
  15. import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
  16. import org.nd4j.linalg.activations.Activation;
  17. import org.nd4j.linalg.api.ndarray.INDArray;
  18. import org.nd4j.linalg.factory.Nd4j;
  19. import org.nd4j.linalg.lossfunctions.LossFunctions;
  20. import org.slf4j.Logger;
  21. import org.slf4j.LoggerFactory;
  22.  
  23. import java.util.Random;
  24.  
  25. /**
  26. * Built for SkyMind Training class
  27. */
  28. public class SimplestNetwork {
  29. private static Logger log = LoggerFactory.getLogger(SimplestNetwork.class);
  30. public static void main(String[] args) throws Exception{
  31. BasicConfigurator.configure();
  32. /*
  33. Most Basic NN that takes a single input
  34. */
  35.  
  36. int seed = 123; // consistent Random Numbers needed for testing, Initial weights are Randomized
  37. Random rng = new Random(seed);
  38.  
  39. int nEpochs = 500; //Number of epochs (full passes of the data)
  40.  
  41. double learningRate = 0.005; // How Fast to adjust weights to minimize error
  42. // Start with Learning Rate of 0.005
  43.  
  44. int numInputs = 1; // number of input nodes
  45.  
  46. int numOutputs = 1; // number of output nodes
  47.  
  48. int nHidden = 5; // number of hidden nodes
  49. /*
  50. Create our input values and expected output values
  51. All data in all Neural Networks are represented as
  52. Numerical arrays, Normalization between 0 and 1 allows for better training
  53. */
  54.  
  55. INDArray input = Nd4j.create(new float[]{(float) 0.5},new int[]{1,1}); // Our input value
  56. INDArray output = Nd4j.create(new float[]{(float) 0.8},new int[]{1,1}); // expected output
  57. log.info("******" + input.toString() + "*********" );
  58.  
  59. /*
  60. Build a MuliLayer Network to train on our dataset
  61. */
  62.  
  63.  
  64. MultiLayerNetwork model = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
  65. .seed(seed)
  66. .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) // most commonly used Optimization algo
  67. .learningRate(learningRate)
  68. .weightInit(WeightInit.XAVIER) // Xavier is a weight randomizer optimized for NN
  69. .updater(Updater.NESTEROVS).momentum(0.09) // How to update the weights start with momentum of 0.09
  70. .list()
  71. .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(nHidden)
  72. .activation(Activation.TANH)
  73. .build())
  74. .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
  75. .activation(Activation.IDENTITY)
  76. .nIn(nHidden).nOut(numOutputs).build())
  77. .pretrain(false).backprop(true).build()
  78. );
  79. model.init();
  80.  
  81. /*
  82. Create a web based UI server to show progress as the network trains
  83. The Listeners for the model are set here as well
  84. One listener to pass stats to the UI
  85. and a Listener to pass progress info to the console
  86. */
  87.  
  88. UIServer uiServer = UIServer.getInstance();
  89. StatsStorage statsStorage = new InMemoryStatsStorage();
  90. model.setListeners(new StatsListener(statsStorage),new ScoreIterationListener(1));
  91. uiServer.attach(statsStorage);
  92. /*
  93. ParamAndGradientIterationListener pgl = ParamAndGradientIterationListener.builder()
  94. .printHeader(true)
  95. .delimiter("|")
  96. .outputToConsole(true)
  97. .printMean(true)
  98. .iterations(1)
  99. .build();
  100.  
  101. model.setListeners(pgl);
  102. */
  103.  
  104.  
  105. //UIServer uiServer = UIServer.getInstance();
  106.  
  107. //Configure where the network information (gradients, activations, score vs. time etc) is to be stored
  108. //Then add the StatsListener to collect this information from the network, as it trains
  109. //StatsStorage statsStorage = new InMemoryStatsStorage(); //Alternative: new FileStatsStorage(File) - see UIStorageExample
  110. //int listenerFrequency = 1;
  111. //model.setListeners(new StatsListener(statsStorage, listenerFrequency));
  112.  
  113. //Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized
  114. //uiServer.attach(statsStorage);
  115.  
  116.  
  117. //Train the network on the full data set, and evaluate in periodically
  118. for( int i=0; i<nEpochs; i++ ){
  119. model.fit(input,output);
  120. INDArray params = model.params();
  121. System.out.println(params);
  122. INDArray output2 = model.output(input);
  123. log.info(output2.toString());
  124. Thread.sleep(100);
  125. }
  126.  
  127. }
  128. }
Add Comment
Please, Sign In to add comment