Advertisement
Guest User

Untitled

a guest
Dec 8th, 2016
131
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.28 KB | None | 0 0
  1. %% Fine Tuning A Deep Neural Network
  2. clear; clc;close all;
  3. imagenet_cnn = load('imagenet-cnn');
  4. net = imagenet_cnn.convnet;
  5. net.Layers
  6.  
  7. %% Perform net surgery
  8. layers = net.Layers(1:end-3);
  9. layers(end+1) = fullyConnectedLayer(12, 'Name', 'fc8_2')
  10. layers(end+1) = softmaxLayer('Name','prob_2');
  11. layers(end+1) = classificationLayer('Name','classificationLayer_2')
  12.  
  13. %% Setup learning rates for fine-tuning
  14.  
  15. % fc 8 - bump up learning rate for last layers
  16. layers(end-2).WeightLearnRateFactor = 100;
  17. layers(end-2).WeightL2Factor = 1;
  18. layers(end-2).BiasLearnRateFactor = 20;
  19. layers(end-2).BiasL2Factor = 0;
  20.  
  21. %% Load Image Data
  22.  
  23. rootFolder = fullfile('E:UniversidadTesisMatlab', 'TesisDataBase');
  24. categories = {'Avion','Banana','Carro','Gato', 'Mango','Perro','Sandia','Tijeras','Silla','Mouse','Calculadora','Arbol'};
  25. imds = imageDatastore(fullfile(rootFolder, categories), 'LabelSource', 'foldernames');
  26. tbl = countEachLabel(imds);
  27.  
  28. %% Equalize number of images of each class in training set
  29. minSetCount = min(tbl{:,2}); % determine the smallest amount of images in a category
  30. % Use splitEachLabel method to trim the set.
  31. imds = splitEachLabel(imds, minSetCount);
  32.  
  33. % Notice that each set now has exactly the same number of images.
  34. countEachLabel(imds)
  35. [trainingDS, testDS] = splitEachLabel(imds, 0.7,'randomize');
  36. % Convert labels to categoricals
  37. trainingDS.Labels = categorical(trainingDS.Labels);
  38. trainingDS.ReadFcn = @readFunctionTrain;
  39.  
  40. %% Setup test data for validation
  41. testDS.Labels = categorical(testDS.Labels);
  42. testDS.ReadFcn = @readFunctionValidation;
  43.  
  44. %% Fine-tune the Network
  45.  
  46. miniBatchSize = 32; % lower this if your GPU runs out of memory.
  47. numImages = numel(trainingDS.Files);
  48. numIterationsPerEpoch = 250;
  49. maxEpochs = 62;
  50. lr = 0.01;
  51. opts = trainingOptions('sgdm', ...
  52. 'InitialLearnRate', lr,...
  53. 'LearnRateSchedule', 'none',...
  54. 'L2Regularization', 0.0005, ...
  55. 'MaxEpochs', maxEpochs, ...
  56. 'MiniBatchSize', miniBatchSize);
  57. net = trainNetwork(trainingDS, layers, opts);
  58.  
  59. function Iout = readFunctionTrain(filename)
  60. % Resize the flowers images to the size required by the network.
  61. I = imread(filename);
  62. % Some images may be grayscale. Replicate the image 3 times to
  63. % create an RGB image.
  64. if ismatrix(I)
  65. I = cat(3,I,I,I);
  66. end
  67. % Resize the image as required for the CNN.
  68. Iout = imresize(I, [227 227]);
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement