daily pastebin goal
40%
SHARE
TWEET

Untitled

a guest Dec 13th, 2018 56 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. public class NetworkManager {
  2.  
  3.     private final List<IndexWord> labels;
  4.     private final ComputationGraph network;
  5.     private final InMemoryStatsStorage statsStorage;
  6.  
  7.     public NetworkManager(List<IndexWord> labels) {
  8.         this.labels = labels;
  9.         network = buildNetwork();
  10.  
  11.         statsStorage = new InMemoryStatsStorage();
  12.         UIServer.getInstance().attach(statsStorage);
  13.     }
  14.  
  15.     private ComputationGraph buildNetwork() {
  16.         ComputationGraph pretrainedNet;
  17.         try {
  18.             pretrainedNet = (ComputationGraph) VGG16.builder().build().initPretrained(PretrainedType.IMAGENET);
  19.         } catch (IOException e) {
  20.             throw new RuntimeException("Failed to load pre-trained network", e);
  21.         }
  22.  
  23.         final FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder()
  24.                 .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
  25.                 .updater(new Nesterovs(1e-4, 0.5))
  26.                 .build();
  27.         ComputationGraph transferGraph = new TransferLearning.GraphBuilder(pretrainedNet)
  28.                 .fineTuneConfiguration(fineTuneConfiguration)
  29.                 .setFeatureExtractor("fc2") // freeze this and below
  30.                 .removeVertexKeepConnections("predictions")
  31.                 .addLayer("predictions",
  32.                         new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
  33.                                 .nIn(4096).nOut(labels.size())
  34.                                 .weightInit(WeightInit.ZERO)
  35.                                 .activation(Activation.SOFTMAX)
  36.                                 .build(), "fc2")
  37.                 .build();
  38.  
  39.         System.out.println("Transfer model:");
  40.         System.out.println(transferGraph.summary());
  41.  
  42.         return transferGraph;
  43.     }
  44.  
  45.     public void train(DataSetIterator trainIterator, DataSetIterator testIterator, int epochs) {
  46.         TransferLearningHelper transferHelper = new TransferLearningHelper(network);
  47.         transferHelper.unfrozenGraph().setListeners(new StatsListener(statsStorage));
  48.  
  49.         System.out.println("Going to featurize images...");
  50.         DataSetIterator featurizedTrain = featurize(trainIterator, transferHelper);
  51.         DataSetIterator featurizedTest = featurize(testIterator, transferHelper);
  52.  
  53.         List<String> labelStrings = labels.stream().map(IndexWord::getLemma).collect(toList());
  54.  
  55.         Evaluation evalBefore = transferHelper.unfrozenGraph().evaluate(featurizedTest, labelStrings);
  56.         System.out.println(evalBefore.stats(false, false));
  57.         featurizedTest.reset();
  58.  
  59.         for (int i = 0; i < epochs; i++) {
  60.             System.out.println("Starting training epoch " + i);
  61.  
  62.             transferHelper.fitFeaturized(featurizedTrain);
  63.             featurizedTrain.reset();
  64.  
  65.             Evaluation eval = transferHelper.unfrozenGraph().evaluate(featurizedTest, labelStrings);
  66.             System.out.println(eval.stats(false, false));
  67.             featurizedTest.reset();
  68.         }
  69.  
  70.         System.out.println("Training complete");
  71.     }
  72.  
  73.     private DataSetIterator featurize(DataSetIterator dataSetIterator, TransferLearningHelper transferHelper) {
  74.         // featurize ahead of time rather than lazily to avoid issues with multiple workspaces
  75.         List<DataSet> featurizeds = new LinkedList<>();
  76.         while (dataSetIterator.hasNext()) {
  77.             DataSet dataSet = dataSetIterator.next();
  78.             featurizeds.add(transferHelper.featurize(dataSet));
  79.         }
  80.         return new CachingDataSetIterator(
  81.                 new ListDataSetIterator<>(featurizeds, dataSetIterator.batch()),
  82.                 new InMemoryDataSetCache()
  83.         );
  84.     }
  85. }
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top