Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- %% Fine Tuning A Deep Neural Network
- % Reset the session: remove workspace variables, clear the command window
- % and close any open figure windows.
- clear;
- clc;
- close all;
- % Load the network to be fine-tuned. The MAT-file is expected to store it
- % in a field named convnet.
- imagenet_cnn = load('imagenet-cnn');
- net = imagenet_cnn.convnet;
- % No semicolon: print the layer array so it can be inspected.
- net.Layers
- %% Perform net surgery
- % Keep every layer of the loaded network except the last three (presumably
- % the original fc8 / softmax / classification output -- check the
- % net.Layers listing above) and append a new head sized for this dataset.
- numClasses = 12; % must equal numel(categories) in the data-loading section
- layers = net.Layers(1:end-3);
- layers(end+1) = fullyConnectedLayer(numClasses, 'Name', 'fc8_2');
- layers(end+1) = softmaxLayer('Name', 'prob_2');
- layers(end+1) = classificationLayer('Name', 'classificationLayer_2');
- %% Setup learning rates for fine-tuning
- % fc8_2 (layers(end-2)): raise its learning-rate factors so the newly added
- % layer learns faster than the retained layers, whose factors are untouched.
- layers(end-2).WeightLearnRateFactor = 100;
- layers(end-2).WeightL2Factor = 1;
- layers(end-2).BiasLearnRateFactor = 20;
- layers(end-2).BiasL2Factor = 0;
- %% Load Image Data
- % NOTE(review): the original literal was 'E:UniversidadTesisMatlab', with no
- % path separators. On Windows "E:folder" resolves relative to the current
- % directory of drive E:, so it almost certainly meant
- % E:\Universidad\Tesis\Matlab. Confirm against the real dataset location.
- rootFolder = fullfile('E:\', 'Universidad', 'Tesis', 'Matlab', 'TesisDataBase');
- if exist(rootFolder, 'dir') ~= 7
-     error('Dataset folder not found: %s', rootFolder);
- end
- % One sub-folder per class; each image is labelled with its folder name.
- categories = {'Avion','Banana','Carro','Gato', 'Mango','Perro','Sandia','Tijeras','Silla','Mouse','Calculadora','Arbol'};
- imds = imageDatastore(fullfile(rootFolder, categories), 'LabelSource', 'foldernames');
- % Per-class image counts, used to balance the classes.
- tbl = countEachLabel(imds);
- %% Equalize number of images of each class in training set
- % Trim every class down to the size of the smallest one. Without
- % 'randomize', splitEachLabel keeps the first minSetCount files per label.
- minSetCount = min(tbl.Count);
- imds = splitEachLabel(imds, minSetCount);
- % Display the per-class counts again; they are now all equal.
- countEachLabel(imds)
- % Random 70/30 split per class: 70% to trainingDS, the rest to testDS.
- [trainingDS, testDS] = splitEachLabel(imds, 0.7, 'randomize');
- % Training split: categorical labels plus its per-image read function.
- trainingDS.Labels = categorical(trainingDS.Labels);
- trainingDS.ReadFcn = @readFunctionTrain;
- %% Setup test data for validation
- % Held-out split gets the same label conversion and its own read function.
- % NOTE(review): readFunctionValidation must exist as a local function of
- % this script (or on the path); it is only invoked when testDS is read.
- testDS.Labels = categorical(testDS.Labels);
- testDS.ReadFcn = @readFunctionValidation;
- %% Fine-tune the Network
- % Training hyper-parameters. trainNetwork derives the number of iterations
- % per epoch from the training-set size and miniBatchSize, so it is not set
- % here.
- miniBatchSize = 32; % lower this if your GPU runs out of memory.
- maxEpochs = 62;
- % Base learning rate; layers with learn-rate factors (the new fc layer)
- % train at lr multiplied by their factor.
- lr = 0.01;
- % NOTE(review): testDS is prepared above but never passed to training or
- % evaluated in this script; consider 'ValidationData' (if your release
- % supports it) or classifying testDS after training.
- opts = trainingOptions('sgdm', ...
-     'InitialLearnRate', lr, ...
-     'LearnRateSchedule', 'none', ...
-     'L2Regularization', 0.0005, ...
-     'MaxEpochs', maxEpochs, ...
-     'MiniBatchSize', miniBatchSize);
- % The fine-tuned network replaces the loaded one in `net`.
- net = trainNetwork(trainingDS, layers, opts);
- function Iout = readFunctionTrain(filename)
- % readFunctionTrain  Read one image file and format it for the CNN input.
- %   Iout = readFunctionTrain(filename) reads the image stored in FILENAME
- %   and returns it resized to 227x227 pixels:
- %     * indexed (palette) images are converted to RGB through their colormap;
- %     * single-channel (grayscale) images are replicated into 3 channels.
- %   The script uses it as the ReadFcn of the training datastore.
- [I, map] = imread(filename);
- % Indexed images hold palette indices, not intensities; map them through
- % the colormap before anything else.
- if ~isempty(map)
-     I = im2uint8(ind2rgb(I, map));
- end
- % Some images may be grayscale. Replicate the image 3 times to
- % create an RGB image.
- if ismatrix(I)
-     I = cat(3, I, I, I);
- end
- % Resize the image as required for the CNN.
- Iout = imresize(I, [227 227]);
- end
- function Iout = readFunctionValidation(filename)
- % readFunctionValidation  ReadFcn for the held-out (test) datastore.
- %   Iout = readFunctionValidation(filename) applies exactly the same
- %   preprocessing as readFunctionTrain, so test images reach the network in
- %   the same format as training images.
- Iout = readFunctionTrain(filename);
- end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement