Guest User

Untitled

a guest
Dec 13th, 2018
88
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.61 KB | None | 0 0
  1. public class NetworkManager {
  2.  
  3. private final List<IndexWord> labels;
  4. private final ComputationGraph network;
  5. private final InMemoryStatsStorage statsStorage;
  6.  
  7. public NetworkManager(List<IndexWord> labels) {
  8. this.labels = labels;
  9. network = buildNetwork();
  10.  
  11. statsStorage = new InMemoryStatsStorage();
  12. UIServer.getInstance().attach(statsStorage);
  13. }
  14.  
  15. private ComputationGraph buildNetwork() {
  16. ComputationGraph pretrainedNet;
  17. try {
  18. pretrainedNet = (ComputationGraph) VGG16.builder().build().initPretrained(PretrainedType.IMAGENET);
  19. } catch (IOException e) {
  20. throw new RuntimeException("Failed to load pre-trained network", e);
  21. }
  22.  
  23. final FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder()
  24. .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
  25. .updater(new Nesterovs(1e-4, 0.5))
  26. .build();
  27. ComputationGraph transferGraph = new TransferLearning.GraphBuilder(pretrainedNet)
  28. .fineTuneConfiguration(fineTuneConfiguration)
  29. .setFeatureExtractor("fc2") // freeze this and below
  30. .removeVertexKeepConnections("predictions")
  31. .addLayer("predictions",
  32. new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
  33. .nIn(4096).nOut(labels.size())
  34. .weightInit(WeightInit.ZERO)
  35. .activation(Activation.SOFTMAX)
  36. .build(), "fc2")
  37. .build();
  38.  
  39. System.out.println("Transfer model:");
  40. System.out.println(transferGraph.summary());
  41.  
  42. return transferGraph;
  43. }
  44.  
  45. public void train(DataSetIterator trainIterator, DataSetIterator testIterator, int epochs) {
  46. TransferLearningHelper transferHelper = new TransferLearningHelper(network);
  47. transferHelper.unfrozenGraph().setListeners(new StatsListener(statsStorage));
  48.  
  49. System.out.println("Going to featurize images...");
  50. DataSetIterator featurizedTrain = featurize(trainIterator, transferHelper);
  51. DataSetIterator featurizedTest = featurize(testIterator, transferHelper);
  52.  
  53. List<String> labelStrings = labels.stream().map(IndexWord::getLemma).collect(toList());
  54.  
  55. Evaluation evalBefore = transferHelper.unfrozenGraph().evaluate(featurizedTest, labelStrings);
  56. System.out.println(evalBefore.stats(false, false));
  57. featurizedTest.reset();
  58.  
  59. for (int i = 0; i < epochs; i++) {
  60. System.out.println("Starting training epoch " + i);
  61.  
  62. transferHelper.fitFeaturized(featurizedTrain);
  63. featurizedTrain.reset();
  64.  
  65. Evaluation eval = transferHelper.unfrozenGraph().evaluate(featurizedTest, labelStrings);
  66. System.out.println(eval.stats(false, false));
  67. featurizedTest.reset();
  68. }
  69.  
  70. System.out.println("Training complete");
  71. }
  72.  
  73. private DataSetIterator featurize(DataSetIterator dataSetIterator, TransferLearningHelper transferHelper) {
  74. // featurize ahead of time rather than lazily to avoid issues with multiple workspaces
  75. List<DataSet> featurizeds = new LinkedList<>();
  76. while (dataSetIterator.hasNext()) {
  77. DataSet dataSet = dataSetIterator.next();
  78. featurizeds.add(transferHelper.featurize(dataSet));
  79. }
  80. return new CachingDataSetIterator(
  81. new ListDataSetIterator<>(featurizeds, dataSetIterator.batch()),
  82. new InMemoryDataSetCache()
  83. );
  84. }
  85. }
Add Comment
Please, Sign In to add comment