package com.research.skindetector;

import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Random;

/**
 * Convolutional neural network class that applies supervised learning.
 *
 * This is the main utility that builds, trains, and evaluates the neural network.
 *
 * Based on the deeplearning4j open source library and tutorials.
 *
 * @author Ronan Konishi
 * @version 1.0
 */
public class NeuralNetwork {
    private static final Logger log = LoggerFactory.getLogger(NeuralNetwork.class);
    File trainData, testData;
    int rngseed, height, width, channels, batchSize, outputNum;
    String netPath;
    Random ranNumGen;
    MultiLayerNetwork model;
    JsonImageRecordReader recordReader;
    DataNormalization scaler;
    DataSetIterator test_iter, train_iter;
    // AsyncDataSetIterator train_iter, test_iter;
    // DataSetIterator iter;

    /**
     * Constructor for data that has not yet been split into training and testing sets.
     * Use when a neural network needs to be built from scratch.
     *
     * @param mixedData Path to the directory with the mixed (unsplit) data
     * @param trainData Path to the directory that will receive the training data
     * @param testData Path to the directory that will receive the testing data
     * @param rngseed Seed that keeps the random number generation reproducible
     * @param height The height of each image in pixels
     * @param width The width of each image in pixels
     * @param channels The number of channels (e.g. 1 for grayscale and 3 for RGB)
     * @param batchSize The number of images processed per iteration
     * @param outputNum The number of nodes in the output layer
     */
    public NeuralNetwork(File mixedData, File trainData, File testData, int rngseed, int height, int width, int channels, int batchSize, int outputNum) throws IOException {
        this.trainData = trainData;
        this.testData = testData;
        this.rngseed = rngseed;
        this.height = height;
        this.width = width;
        this.channels = channels;
        this.batchSize = batchSize;
        this.outputNum = outputNum;
        ranNumGen = new Random(rngseed);
        vectorization();
        dataSplitter(mixedData, trainData, testData);
        log.info("Building Neural Network from scratch...");
        buildNet();
    }

    /**
     * Constructor for data that has not yet been split into training and testing sets.
     * Use when importing an already built neural network.
     *
     * @param mixedData Path to the directory with the mixed (unsplit) data
     * @param trainData Path to the directory that will receive the training data
     * @param testData Path to the directory that will receive the testing data
     * @param rngseed Seed that keeps the random number generation reproducible
     * @param height The height of each image in pixels
     * @param width The width of each image in pixels
     * @param channels The number of channels (e.g. 1 for grayscale and 3 for RGB)
     * @param batchSize The number of images processed per iteration
     * @param outputNum The number of nodes in the output layer
     * @param netPath The path from which the neural network is imported
     */
    public NeuralNetwork(File mixedData, File trainData, File testData, int rngseed, int height, int width, int channels, int batchSize, int outputNum, String netPath) throws IOException {
        this.trainData = trainData;
        this.testData = testData;
        this.rngseed = rngseed;
        this.height = height;
        this.width = width;
        this.channels = channels;
        this.batchSize = batchSize;
        this.outputNum = outputNum;
        this.netPath = netPath;
        ranNumGen = new Random(rngseed);
        vectorization();
        dataSplitter(mixedData, trainData, testData);
        log.info("Building Neural Network from import...");
        loadNet(netPath);
    }

    /**
     * Constructor for training and testing data that are already separated.
     * Use when a neural network needs to be built from scratch.
     *
     * @param trainData Path to the directory with the training data
     * @param testData Path to the directory with the testing data
     * @param rngseed Seed that keeps the random number generation reproducible
     * @param height The height of each image in pixels
     * @param width The width of each image in pixels
     * @param channels The number of channels (e.g. 1 for grayscale and 3 for RGB)
     * @param batchSize The number of images processed per iteration
     * @param outputNum The number of nodes in the output layer
     */
    public NeuralNetwork(File trainData, File testData, int rngseed, int height, int width, int channels, int batchSize, int outputNum) throws IOException {
        this.trainData = trainData;
        this.testData = testData;
        this.rngseed = rngseed;
        this.height = height;
        this.width = width;
        this.channels = channels;
        this.batchSize = batchSize;
        this.outputNum = outputNum;
        ranNumGen = new Random(rngseed);
        vectorization();
        log.info("Building Neural Network from scratch...");
        buildNet();
    }

    /**
     * Constructor for training and testing data that are already separated.
     * Use when importing an already built neural network.
     *
     * @param trainData Path to the directory with the training data
     * @param testData Path to the directory with the testing data
     * @param rngseed Seed that keeps the random number generation reproducible
     * @param height The height of each image in pixels
     * @param width The width of each image in pixels
     * @param channels The number of channels (e.g. 1 for grayscale and 3 for RGB)
     * @param batchSize The number of images processed per iteration
     * @param outputNum The number of nodes in the output layer
     * @param netPath The path from which the neural network is imported
     */
    public NeuralNetwork(File trainData, File testData, int rngseed, int height, int width, int channels, int batchSize, int outputNum, String netPath) throws IOException {
        this.trainData = trainData;
        this.testData = testData;
        this.rngseed = rngseed;
        this.height = height;
        this.width = width;
        this.channels = channels;
        this.batchSize = batchSize;
        this.outputNum = outputNum;
        this.netPath = netPath;
        ranNumGen = new Random(rngseed);
        vectorization();
        log.info("Building Neural Network from import...");
        loadNet(netPath);
    }

    /**
     * Prepares the record reader that vectorizes the labeled images.
     */
    private void vectorization(){
        JsonPathLabelGenerator label = new JsonPathLabelGenerator();
        recordReader = new JsonImageRecordReader(height, width, channels, label);
        // recordReader.setListeners(new LogRecordListener());
    }

    /**
     * Builds the convolutional neural network, trained with stochastic
     * gradient descent and L2 regularization.
     */
    private MultiLayerNetwork buildNet() {
        int layer1 = 1000; // currently unused

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(rngseed)
            .iterations(1) // training iterations per minibatch
            .regularization(true).l2(0.0005)
            .learningRate(0.01)
            .weightInit(WeightInit.XAVIER)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(new Nesterovs(0.9))
            .list()
            .layer(0, new ConvolutionLayer.Builder(5, 5)
                // nIn and nOut specify depth: nIn is the number of input channels
                // and nOut is the number of filters to be applied
                .nIn(channels) // 3
                .stride(1, 1)
                .nOut(20)
                .activation(Activation.LEAKYRELU)
                .build())
            .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2,2)
                .stride(2,2)
                .build())
            .layer(2, new ConvolutionLayer.Builder(5, 5)
                // Note that nIn need not be specified in later layers
                .stride(1, 1)
                .nOut(50)
                .activation(Activation.LEAKYRELU)
                .build())
            .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2,2)
                .stride(2,2)
                .build())
            .layer(4, new DenseLayer.Builder().activation(Activation.LEAKYRELU).nOut(500).build())
            .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .build())
            .setInputType(InputType.convolutionalFlat(height,width,channels))
            .backprop(true).pretrain(false).build();

        model = new MultiLayerNetwork(conf);
        model.init();

        return model;
    }
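
    // Shape sketch (illustrative, not from the original source): assuming the default
    // 'truncate' convolution mode and, for example, a 28x28 single-channel input,
    // activations flow roughly as
    //   28x28x1 -> conv 5x5 -> 24x24x20 -> maxpool 2x2 -> 12x12x20
    //           -> conv 5x5 ->  8x8x50  -> maxpool 2x2 ->  4x4x50
    //           -> dense(500) -> softmax over outputNum classes.
    // Other input sizes scale the spatial dimensions accordingly.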

    public MultiLayerNetwork getNet(){
        return model;
    }

    /**
     * Attaches the deeplearning4j training UI so training statistics can be monitored.
     */
    public void UIenable(){
        UIServer uiServer = UIServer.getInstance();
        StatsStorage statsStorage = new InMemoryStatsStorage(); // alternative: new FileStatsStorage(File), for saving and loading later
        int listenerFrequency = 1;
        model.setListeners(new StatsListener(statsStorage, listenerFrequency));
        uiServer.attach(statsStorage);
    }
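
    // Note (added): once the StatsListener is attached, the deeplearning4j training UI is
    // normally reachable in a browser at http://localhost:9000 (the library's default port).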

    /**
     * Trains the neural network with the training data.
     *
     * @param numEpochs The number of times the model iterates through the training data set
     */
    public void train(int numEpochs) throws IOException {
        // UI enable
        // UIenable();

        // if trainingReady is true and evaluatingReady is false
        FileSplit train = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, this.ranNumGen);
        // recordReaderInit(train);

        recordReader.initialize(train);
        DataSetIterator temp_iter = new RecordReaderDataSetIterator(recordReader,batchSize,1,outputNum);
        scaler = new ImagePreProcessingScaler(0,1);
        scaler.fit(temp_iter);
        temp_iter.setPreProcessor(scaler);
        train_iter = temp_iter;
        // train_iter = new AsyncDataSetIterator(temp_iter);

        // Displays how well the neural network is training
        // model.setListeners(new ScoreIterationListener(10));

        // Disable the Java garbage collector
        // Nd4j.getMemoryManager().setAutoGcWindow(5000);
        // Nd4j.getMemoryManager().togglePeriodicGc(false);
        //
        // ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
        //     // DataSets prefetching options. Buffer size per worker.
        //     .prefetchBuffer(8)
        //     // set number of workers equal to number of GPUs
        //     .workers(2)
        //     // rare averaging improves performance but might reduce model accuracy
        //     .averagingFrequency(5)
        //     // if set to TRUE, the model score will be reported after every averaging
        //     .reportScoreAfterAveraging(false)
        //     // 3 options here: NONE, SINGLE, SEPARATE
        //     .workspaceMode(WorkspaceMode.SEPARATE)
        //     .build();

        System.out.println("Starting to fit model");
        for(int i = 0; i < numEpochs; i++) {
            model.fit(train_iter);
        }
        System.out.println("Finished fitting model");
    }

    /**
     * Evaluates the neural network by running the network through the testing data set.
     *
     * @return eval The output of the evaluation
     */
    public Evaluation evaluate() throws IOException {
        // if trainingReady is false and evaluatingReady is true
        FileSplit test = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, ranNumGen);
        // recordReaderInit(test);

        recordReader.initialize(test);
        test_iter = new RecordReaderDataSetIterator(recordReader,batchSize,1,outputNum);
        scaler = new ImagePreProcessingScaler(0,1);
        scaler.fit(test_iter);
        test_iter.setPreProcessor(scaler);

        Evaluation eval = new Evaluation(outputNum);
        // ROC roceval = new ROC(outputNum);
        // model.doEvaluation(iteratorTest, eval, roceval);

        while(test_iter.hasNext()) {
            DataSet next = test_iter.next();
            INDArray output = model.output(next.getFeatureMatrix());
            eval.eval(next.getLabels(), output);
        }
        return eval;
    }

    /**
     * Saves the trained neural network.
     *
     * @param filepath The path and file to which the neural network should be saved
     */
    public void saveBuild(String filepath) throws IOException {
        File saveLocation = new File(filepath);
        boolean saveUpdater = false; // set to true to also save the updater state, which is needed to resume training later
        ModelSerializer.writeModel(model,saveLocation,saveUpdater);
    }

    /**
     * For testing purposes. Displays images with labels from a given dataset.
     *
     * @param numImages The number of images to display
     * @param iter The iterator over the dataset to display
     */
    public void imageToLabelDisplay(int numImages, DataSetIterator iter){
        for (int i = 0; i < numImages; i++) {
            DataSet ds = iter.next();
            System.out.println(ds);
            System.out.println(iter.getLabels());
        }
    }

    /**
     * Randomly splits a mixed dataset into the training and testing directories,
     * moving roughly 75% of the file pairs into the training set.
     */
    private void dataSplitter(File mixedDataset, File trainData, File testData) throws IOException {
        this.trainData = trainData;
        this.testData = testData;

        // Files are expected to come in pairs (image plus JSON label), so they are moved two at a time
        File[] mixedData = mixedDataset.listFiles();
        for(int i = 0; i < mixedData.length/2; i++){
            double random = Math.random();
            File first = mixedData[i*2];
            File second = mixedData[i*2+1];

            if (random > 0.25) {
                Files.move(first.toPath(), new File(trainData, first.getName()).toPath());
                Files.move(second.toPath(), new File(trainData, second.getName()).toPath());
            } else {
                Files.move(first.toPath(), new File(testData, first.getName()).toPath());
                Files.move(second.toPath(), new File(testData, second.getName()).toPath());
            }
        }
    }

    /**
     * Loads a previously saved neural network from disk.
     *
     * @param NetPath The path to the saved network file
     */
    private void loadNet(String NetPath) throws IOException {
        File locationToSave = new File(NetPath);
        model = ModelSerializer.restoreMultiLayerNetwork(locationToSave);
    }

    // /**
    //  * Creates a record reader.
    //  *
    //  * @param file Name of path to the database wanting to be initialized
    //  */
    // private void recordReaderInit(FileSplit file) throws IOException {
    //     // recordReader.reset();
    //     recordReader.initialize(file);
    //     DataSetIterator temp_iter = new RecordReaderDataSetIterator(recordReader,batchSize,1,outputNum);
    //     scaler = new ImagePreProcessingScaler(0,1);
    //     scaler.fit(temp_iter);
    //     temp_iter.setPreProcessor(scaler);
    //     iter = new AsyncDataSetIterator(temp_iter);
    // }

    public DataSetIterator getTrainIter(){
        return train_iter;
    }
}
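
A minimal usage sketch (not part of the original paste): the driver class below shows how NeuralNetwork might be invoked end to end. The directory paths, image dimensions, batch size, seed, and epoch count are hypothetical placeholders, and the class would live in its own file in the same package.

package com.research.skindetector;

import java.io.File;
import java.io.IOException;

public class TrainSkinDetector {
    public static void main(String[] args) throws IOException {
        File trainDir = new File("data/train"); // placeholder path
        File testDir = new File("data/test");   // placeholder path

        // Build a fresh network for 28x28 RGB images, batch size 64, two output classes
        NeuralNetwork net = new NeuralNetwork(trainDir, testDir, 123, 28, 28, 3, 64, 2);

        net.train(5);                               // iterate over the training set 5 times
        System.out.println(net.evaluate().stats()); // accuracy, precision, recall, F1
        net.saveBuild("skin-detector-model.zip");   // persist the trained model to disk
    }
}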