Advertisement
Caio_25

IntelliJ

Sep 11th, 2019
150
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.49 KB | None | 0 0
  1. package dl;
  2.  
  3. import java.io.File;
  4. import java.io.IOException;
  5. import org.datavec.api.records.reader.RecordReader;
  6. import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
  7. import org.datavec.api.split.FileSplit;
  8. import org.nd4j.linalg.activations.Activation;
  9. import org.nd4j.linalg.api.ndarray.INDArray;
  10. import org.nd4j.linalg.learning.config.Sgd;
  11. import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
  12. import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
  13. import org.nd4j.linalg.dataset.DataSet;
  14. import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
  15. import org.deeplearning4j.eval.Evaluation;
  16. import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
  17. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  18. import org.deeplearning4j.nn.conf.layers.Layer;
  19. import org.deeplearning4j.nn.conf.layers.DenseLayer;
  20. import org.deeplearning4j.nn.conf.layers.OutputLayer;
  21. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  22. import org.deeplearning4j.nn.weights.WeightInit;
  23. import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
  24.  
  25. import javax.xml.crypto.Data;
  26.  
  27.  
  28. public class App1_DL4J {
  29.  
  30.  
  31. static final int input_nodes = 784;
  32. static final int hidden_nodes = 200;
  33. static final int output_nodes = 10;
  34. static final double learning_rate = 0.1;
  35. static final int random_seed = 123;
  36. static final int batch_size = 32;
  37. static final int epochs = 2;
  38.  
  39. public static void main(String[] args)throws IOException, InterruptedException {
  40.  
  41.  
  42. MultiLayerConfiguration configuration = configureNeuralNetwork();
  43. MultiLayerNetwork neuralNetwork = createNeuralNetworkModel(configuration);
  44.  
  45. DataSetIterator itfile_train;
  46. String training_file_name = "data/mnist_train.csv";
  47. itfile_train = readCSV_File(training_file_name);
  48. trainNeuralNetwork(neuralNetwork, itfile_train);
  49.  
  50. String evaluete_file_name = "data/mnist_test.csv";
  51. DataSetIterator itfile_evaluete;
  52. itfile_evaluete = readCSV_File(evaluete_file_name);
  53. evalueteNeuralNetwork(neuralNetwork, itfile_evaluete);
  54.  
  55. }
  56.  
  57. private static MultiLayerConfiguration configureNeuralNetwork()
  58. {
  59. int random_seed = 123;
  60.  
  61. Layer Wih = configureWih();
  62. Layer Who = configureWho();
  63.  
  64. return (new NeuralNetConfiguration.Builder()
  65. .seed(random_seed)
  66. .updater(new Sgd(learning_rate)) // Stochastic Descending Gradient
  67. .list()
  68. .layer(0, Wih)
  69. .layer(1, Who)
  70. .pretrain(false)
  71. .backprop(true)
  72. .build());
  73.  
  74. }
  75.  
  76.  
  77. private static MultiLayerNetwork createNeuralNetworkModel(MultiLayerConfiguration configuration)
  78. {
  79. MultiLayerNetwork model = new MultiLayerNetwork(configuration);
  80. model.init();
  81. model.setListeners(new ScoreIterationListener(1));
  82.  
  83. return model;
  84. }
  85.  
  86. private static DataSetIterator readCSV_File(String nome_arquivo) throws IOException, InterruptedException
  87. {
  88. RecordReader record = new CSVRecordReader();
  89. record.initialize(new FileSplit(new File(nome_arquivo)));
  90.  
  91. return new RecordReaderDataSetIterator (record, batch_size, 0, 10);
  92.  
  93. }
  94.  
  95.  
  96. private static void trainNeuralNetwork(MultiLayerNetwork neuralNetwork, DataSetIterator itfile)
  97. {
  98. for(int i = 0; i < epochs; i++)
  99. {
  100. neuralNetwork.fit(itfile);
  101. }
  102. }
  103.  
  104. public static void evalueteNeuralNetwork(MultiLayerNetwork model, DataSetIterator itfile)
  105. {
  106.  
  107. Evaluation score = new Evaluation(10);
  108. while(itfile.hasNext()){
  109. DataSet next_data = itfile.next();
  110. INDArray array = model.output(next_data.getFeatures());
  111. score.eval(next_data.getLabels(), array);
  112. }
  113. System.out.println(score.stats());
  114.  
  115. }
  116.  
  117. private static Layer configureWih()
  118. {
  119.  
  120. return(new DenseLayer.Builder()
  121. .nIn(input_nodes)
  122. .nOut(hidden_nodes)
  123. .activation(Activation.SIGMOID)
  124. .weightInit(WeightInit.NORMAL)
  125. .build() );
  126. }
  127.  
  128.  
  129. private static Layer configureWho()
  130. {
  131. return (new OutputLayer.Builder(LossFunction.MSE)
  132. .nIn(hidden_nodes)
  133. .nOut(output_nodes)
  134. .activation(Activation.SIGMOID)
  135. .weightInit(WeightInit.NORMAL)
  136. .build() );
  137. }
  138.  
  139.  
  140.  
  141.  
  142.  
  143. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement