Guest User

Untitled

a guest
Mar 23rd, 2018
78
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.37 KB | None | 0 0
  1. package com.research.skindetector;
  2.  
  3. import org.datavec.api.records.listener.impl.LogRecordListener;
  4. import org.datavec.api.split.FileSplit;
  5. import org.datavec.image.loader.NativeImageLoader;
  6. import org.deeplearning4j.api.storage.StatsStorage;
  7. import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
  8. import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
  9. import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
  10. import org.deeplearning4j.nn.api.OptimizationAlgorithm;
  11. import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
  12. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  13. import org.deeplearning4j.nn.conf.inputs.InputType;
  14. import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
  15. import org.deeplearning4j.nn.conf.layers.DenseLayer;
  16. import org.deeplearning4j.nn.conf.layers.OutputLayer;
  17. import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
  18. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  19. import org.deeplearning4j.nn.weights.WeightInit;
  20. import org.deeplearning4j.ui.api.UIServer;
  21. import org.deeplearning4j.ui.stats.StatsListener;
  22. import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
  23. import org.nd4j.linalg.activations.Activation;
  24. import org.nd4j.linalg.dataset.DataSet;
  25. import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
  26. import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
  27. import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
  28. import org.nd4j.linalg.factory.Nd4j;
  29. import org.nd4j.linalg.learning.config.Nesterovs;
  30. import org.nd4j.linalg.lossfunctions.LossFunctions;
  31. import org.slf4j.Logger;
  32. import org.slf4j.LoggerFactory;
  33.  
  34. import java.io.File;
  35. import java.io.IOException;
  36. import java.util.Random;
  37.  
  38. /**
  39. * The heart of the program that calls everything that needs to run.
  40. *
  41. * @author Ronan Konishi
  42. * @version 1.0
  43. */
  44. public class Main {
  45. static int rngseed;
  46. static Random ranNumGen;
  47. static JsonImageRecordReader recordReader;
  48.  
  49. static int height = 28;
  50. static int width = 28;
  51. static int nChannels = 3; // Number of input channels
  52. static int outputNum = 2; // The number of possible outcomes
  53. static int iterations = 1; // Number of training iterations
  54. static int seed = 123; //
  55. static int numEpochs = 1; //number of iterations through entire dataset
  56.  
  57. private static Logger log = LoggerFactory.getLogger(Main.class);
  58.  
  59. public static void main(String[] args) throws IOException {
  60. int batchSize = 1000;
  61. File mixedData = new File("C:\\Users\\ronan\\Desktop\\testsmall\\mixedData\\");
  62. File trainData = new File("C:\\Users\\ronan\\Desktop\\testsmall\\trainData\\");
  63. File testData = new File("C:\\Users\\ronan\\Desktop\\testsmall\\testData\\");
  64. NeuralNetwork network = new NeuralNetwork(mixedData, trainData, testData, rngseed, height, width, nChannels, batchSize, outputNum);
  65. MultiLayerNetwork net = network.getNet();
  66.  
  67. // log.info("*****TRAIN MODEL********");
  68. network.train(numEpochs);
  69.  
  70. UIServer uiServer = UIServer.getInstance();
  71.  
  72. StatsStorage statsStorage = new InMemoryStatsStorage(); //Alternative: new FileStatsStorage(File) - see UIStorageExample
  73. int listenerFrequency = 1;
  74. net.setListeners(new StatsListener(statsStorage, listenerFrequency));
  75.  
  76. uiServer.attach(statsStorage);
  77.  
  78. net.fit(network.getTrainIter());
  79.  
  80. System.out.println("DONE");
  81. }
  82. }
Add Comment
Please, Sign In to add comment