Not a member of Pastebin yet? Sign up — it unlocks many cool features!
- public class NetworkManager {
- private final List<IndexWord> labels;
- private final ComputationGraph network;
- private final InMemoryStatsStorage statsStorage;
- public NetworkManager(List<IndexWord> labels) {
- this.labels = labels;
- network = buildNetwork();
- statsStorage = new InMemoryStatsStorage();
- UIServer.getInstance().attach(statsStorage);
- }
- private ComputationGraph buildNetwork() {
- ComputationGraph pretrainedNet;
- try {
- pretrainedNet = (ComputationGraph) VGG16.builder().build().initPretrained(PretrainedType.IMAGENET);
- } catch (IOException e) {
- throw new RuntimeException("Failed to load pre-trained network", e);
- }
- final FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder()
- .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
- .updater(new Nesterovs(1e-4, 0.5))
- .build();
- ComputationGraph transferGraph = new TransferLearning.GraphBuilder(pretrainedNet)
- .fineTuneConfiguration(fineTuneConfiguration)
- .setFeatureExtractor("fc2") // freeze this and below
- .removeVertexKeepConnections("predictions")
- .addLayer("predictions",
- new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
- .nIn(4096).nOut(labels.size())
- .weightInit(WeightInit.ZERO)
- .activation(Activation.SOFTMAX)
- .build(), "fc2")
- .build();
- System.out.println("Transfer model:");
- System.out.println(transferGraph.summary());
- return transferGraph;
- }
- public void train(DataSetIterator trainIterator, DataSetIterator testIterator, int epochs) {
- TransferLearningHelper transferHelper = new TransferLearningHelper(network);
- transferHelper.unfrozenGraph().setListeners(new StatsListener(statsStorage));
- System.out.println("Going to featurize images...");
- DataSetIterator featurizedTrain = featurize(trainIterator, transferHelper);
- DataSetIterator featurizedTest = featurize(testIterator, transferHelper);
- List<String> labelStrings = labels.stream().map(IndexWord::getLemma).collect(toList());
- Evaluation evalBefore = transferHelper.unfrozenGraph().evaluate(featurizedTest, labelStrings);
- System.out.println(evalBefore.stats(false, false));
- featurizedTest.reset();
- for (int i = 0; i < epochs; i++) {
- System.out.println("Starting training epoch " + i);
- transferHelper.fitFeaturized(featurizedTrain);
- featurizedTrain.reset();
- Evaluation eval = transferHelper.unfrozenGraph().evaluate(featurizedTest, labelStrings);
- System.out.println(eval.stats(false, false));
- featurizedTest.reset();
- }
- System.out.println("Training complete");
- }
- private DataSetIterator featurize(DataSetIterator dataSetIterator, TransferLearningHelper transferHelper) {
- // featurize ahead of time rather than lazily to avoid issues with multiple workspaces
- List<DataSet> featurizeds = new LinkedList<>();
- while (dataSetIterator.hasNext()) {
- DataSet dataSet = dataSetIterator.next();
- featurizeds.add(transferHelper.featurize(dataSet));
- }
- return new CachingDataSetIterator(
- new ListDataSetIterator<>(featurizeds, dataSetIterator.batch()),
- new InMemoryDataSetCache()
- );
- }
- }
Add Comment
Please sign in to add a comment.