Guest User

Untitled

a guest
Sep 30th, 2020
101
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.49 KB | None | 0 0
  1. /* *****************************************************************************
  2. * Copyright (c) 2020 Konduit K.K.
  3. * Copyright (c) 2015-2019 Skymind, Inc.
  4. *
  5. * This program and the accompanying materials are made available under the
  6. * terms of the Apache License, Version 2.0 which is available at
  7. * https://www.apache.org/licenses/LICENSE-2.0.
  8. *
  9. * Unless required by applicable law or agreed to in writing, software
  10. * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  11. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  12. * License for the specific language governing permissions and limitations
  13. * under the License.
  14. *
  15. * SPDX-License-Identifier: Apache-2.0
  16. ******************************************************************************/
  17.  
  18. package org.deeplearning4j.examples.quickstart.modeling.feedforward.classification;
import java.io.File;
import java.io.FileNotFoundException;
import java.time.Duration;
import java.time.Instant;
import java.util.Date;
import java.util.concurrent.TimeUnit;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.AsyncShieldDataSetIterator;
import org.deeplearning4j.examples.utils.DownloaderUtility;
import org.deeplearning4j.examples.utils.PlotUtil;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
  45.  
  46. /**
  47. * "Linear" Data Classification Example
  48. * <p>
  49. * Based on the data from Jason Baldridge:
  50. * https://github.com/jasonbaldridge/try-tf/tree/master/simdata
  51. *
  52. * @author Josh Patterson
  53. * @author Alex Black (added plots)
  54. */
  55. @SuppressWarnings("DuplicatedCode")
  56. public class scs {
  57.  
  58. public static boolean visualize = true;
  59. public static String dataLocalPath;
  60.  
  61. public static void main(String[] args) throws Exception {
  62. int seed = 123;
  63. double learningRate = 0.01;
  64. int batchSize = 25;
  65. int nEpochs = 20;
  66.  
  67. int numInputs = 7;
  68. int numOutputs = 6;
  69. int numHiddenNodes = 20;
  70.  
  71. //dataLocalPath = ''
  72. //Load the training data:
  73. RecordReader rr = new CSVRecordReader();
  74. rr.initialize(new FileSplit(new File( "scs_TRAIN.csv")));
  75. DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 7,6);
  76. DataSetIterator trainI = new AsyncShieldDataSetIterator(trainIter);
  77.  
  78. //Load the test/evaluation data:
  79. RecordReader rrTest = new CSVRecordReader();
  80. rrTest.initialize(new FileSplit(new File("scs_TEST.csv")));
  81. DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 7,6);
  82. DataSetIterator testI = new AsyncShieldDataSetIterator(testIter);
  83.  
  84. MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
  85. .seed(seed)
  86. .weightInit(WeightInit.XAVIER)
  87. .updater(new Nesterovs(learningRate, 0.9))
  88. .list()
  89. .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
  90. .activation(Activation.RELU)
  91. .dropOut(0.75)
  92. .build())
  93. //.layer(new DropoutLayer(0.5))
  94. .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
  95. .activation(Activation.SOFTMAX)
  96. .nIn(numHiddenNodes).nOut(numOutputs).build())
  97. .build();
  98. Date date_start = new Date();
  99. System.out.println(date_start.toString());
  100. MultiLayerNetwork model = new MultiLayerNetwork(conf);
  101. model.init();
  102. model.setListeners(new ScoreIterationListener(100)); //Print score every 100 parameter updates
  103.  
  104. model.fit(trainI, nEpochs);
  105.  
  106.  
  107. System.out.println("Evaluate model....");
  108. Evaluation eval = new Evaluation(numOutputs);
  109. while (testI.hasNext()) {
  110. DataSet t = testI.next();
  111. INDArray features = t.getFeatures();
  112. INDArray labels = t.getLabels();
  113. INDArray predicted = model.output(features, false);
  114. eval.eval(labels, predicted);
  115. }
  116. //An alternate way to do the above loop
  117. //Evaluation evalResults = model.evaluate(testIter);
  118.  
  119. //Print the evaluation statistics
  120. System.out.println(eval.stats());
  121. Date date_end = new Date();
  122. System.out.println("\n****************Example finished********************");
  123. System.out.println("\n******Time start******\n");
  124. System.out.println(date_start.toString());
  125. System.out.println("\n******Time end******\n");
  126. System.out.println(date_end.toString());
  127. //Training is complete. Code that follows is for plotting the data & predictions only
  128.  
  129. }
  130.  
  131. }
  132.  
Add Comment
Please, Sign In to add comment