Advertisement
Guest User

Untitled

a guest
Jan 22nd, 2020
66
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
MatLab 2.25 KB | None | 0 0
  % Reset the workspace, close all figures and clear the command window.
  % `clearvars` clears every workspace variable; the command form
  % `clear vars` would only clear a variable literally named `vars`.
  clearvars;
  close all;
  clc;

  % Load the digit images shipped with the Deep Learning Toolbox examples.
  % Each image is labelled with the name of the subfolder that contains it.
  digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos','nndatasets','DigitDataset');
  imds = imageDatastore(digitDatasetPath, 'IncludeSubfolders',true,'LabelSource','foldernames');

  % Learning rates to sweep: three linearly spaced ranges (20 + 20 + 10 = 50
  % values) covering 1e-5 .. 0.1, concatenated in ascending order.
  alpha_vals = [linspace(0.00001, 0.0001, 20) linspace(0.0005, 0.01, 20) linspace(0.05, 0.1, 10)];
  alpha_count = numel(alpha_vals);

  % One accuracy per learning rate, filled in by the training sweep.
  train_accuracy = zeros(1, alpha_count);
  validation_accuracy = zeros(1, alpha_count);
  12.  
  % Show a random sample of the dataset as a grid of thumbnails.
  % The sample is drawn from the number of files the datastore actually
  % holds and is capped at the grid capacity, so smaller datasets still work.
  preview_rows = 5;
  preview_cols = 6;
  num_files = numel(imds.Files);
  num_preview = min(preview_rows * preview_cols, num_files);

  figure;
  perm = randperm(num_files, num_preview);
  for i = 1:num_preview
      subplot(preview_rows, preview_cols, i);
      imshow(imds.Files{perm(i)});
  end

  % Number of images per label (no semicolon, so the table is displayed).
  labelCount = countEachLabel(imds)

  % Size of the first image, displayed as a sanity check against the
  % [28 28 1] input size used by the network's imageInputLayer.
  img = readimage(imds,1);
  size(img)
  24.  
  % Split once, before the sweep, so that every learning rate is trained and
  % evaluated on the same partition: 750 images per label for training and
  % the remainder for validation. Re-splitting inside the loop would mix
  % split-to-split noise into the learning-rate comparison.
  numTrainFiles = 750;
  [imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,'randomize');
  YTrain = imdsTrain.Labels;
  YValidation = imdsValidation.Labels;

  % Small CNN with three conv(3x3) / batch-norm / ReLU stages (8, 16 and 32
  % filters) and 2x2 max pooling after the first two, followed by a 10-way
  % softmax classifier. The layer array does not depend on the learning
  % rate, so it is built once and reused for every training run.
  layers = [
      imageInputLayer([28 28 1])

      convolution2dLayer(3,8,'Padding','same')
      batchNormalizationLayer
      reluLayer

      maxPooling2dLayer(2,'Stride',2)

      convolution2dLayer(3,16,'Padding','same')
      batchNormalizationLayer
      reluLayer

      maxPooling2dLayer(2,'Stride',2)

      convolution2dLayer(3,32,'Padding','same')
      batchNormalizationLayer
      reluLayer

      fullyConnectedLayer(10)
      softmaxLayer
      classificationLayer];

  % Train one network per learning rate and record its accuracies.
  for i = 1:alpha_count
      tic
      % %g keeps small learning rates readable (e.g. 1.47368e-05); %f would
      % round them to six decimal places.
      fprintf('Iter: %d/%d, alpha: %g.\n', i, alpha_count, alpha_vals(i));

      options = trainingOptions('sgdm','InitialLearnRate',alpha_vals(i), ...
          'MaxEpochs',5, ...
          'Shuffle','once', ...
          'ValidationData',imdsValidation, ...
          'ValidationFrequency',25, ...
          'Verbose',false);

      net = trainNetwork(imdsTrain,layers,options);

      % Accuracy = fraction of images whose predicted label matches the truth.
      train_accuracy(i) = mean(classify(net,imdsTrain) == YTrain);
      validation_accuracy(i) = mean(classify(net,imdsValidation) == YValidation);
      fprintf('Done. Time: %f s, Train acc: %f, validation acc: %f\n', toc, train_accuracy(i), validation_accuracy(i));
  end
  75.  
  % Plot train and validation accuracy against the learning rate. The swept
  % learning rates span several orders of magnitude, so a logarithmic x-axis
  % keeps the smallest values from being squashed against zero.
  figure;
  semilogx(alpha_vals, train_accuracy, '-o')
  hold on;
  semilogx(alpha_vals, validation_accuracy, '-o')
  hold off;
  grid on;
  xlabel('Initial learning rate')
  ylabel('Accuracy')
  legend('Train', 'Validation', 'Location', 'best')
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement