Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
% Sweep the SGDM initial learning rate and record training/validation
% accuracy of a small CNN on MATLAB's built-in DigitDataset.
clearvars;   % was "clear vars", which only clears a variable named 'vars'
close all;
clc;

% Load the bundled digit dataset (10 classes of 28x28 grayscale images);
% folder names provide the labels.
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');

% Learning rates to sweep: dense sampling at the low end, coarser above.
alpha_vals = [linspace(0.00001, 0.0001, 20) ...
              linspace(0.0005, 0.01, 20) ...
              linspace(0.05, 0.1, 10)];
alpha_count = numel(alpha_vals);
train_accuracy = zeros(1, alpha_count);
validation_accuracy = zeros(1, alpha_count);

% Preview 30 random samples from the dataset.
figure;
perm = randperm(numel(imds.Files), 30);  % was hard-coded 10000; use actual count
for i = 1:30
    subplot(5,6,i);
    imshow(imds.Files{perm(i)});
end

labelCount = countEachLabel(imds)  % no semicolon: display class distribution
img = readimage(imds,1);
size(img)                          % no semicolon: display image dimensions

% Split ONCE, outside the sweep, so every learning rate trains and
% validates on the same partition. (The original re-split with
% 'randomize' every iteration, adding sampling noise to the comparison.)
numTrainFiles = 750;
[imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,'randomize');

% Small CNN: three conv/BN/ReLU stages, two 2x2 max-pools, 10-way softmax.
% Loop-invariant, so built once outside the sweep.
layers = [
    imageInputLayer([28 28 1])

    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)

    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)

    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer

    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

for i = 1:alpha_count
    tic
    fprintf('Iter: %d/%d, alpha: %f.\n', i, alpha_count, alpha_vals(i));

    % Only the initial learning rate changes between iterations.
    options = trainingOptions('sgdm', ...
        'InitialLearnRate',alpha_vals(i), ...
        'MaxEpochs',5, ...
        'Shuffle','once', ...
        'ValidationData',imdsValidation, ...
        'ValidationFrequency',25, ...
        'Verbose',false);

    net = trainNetwork(imdsTrain,layers,options);

    % Accuracy on both partitions for this learning rate.
    YPred_train = classify(net,imdsTrain);
    YTrain = imdsTrain.Labels;
    YPred = classify(net,imdsValidation);
    YValidation = imdsValidation.Labels;
    train_accuracy(i) = sum(YPred_train == YTrain)/numel(YTrain);
    validation_accuracy(i) = sum(YPred == YValidation)/numel(YValidation);

    fprintf('Done. Time: %f s, Train acc: %f, validation acc: %f\n', ...
        toc, train_accuracy(i), validation_accuracy(i));
end

% Accuracy vs. learning rate; log x-axis since alpha spans four decades.
figure;
semilogx(alpha_vals, train_accuracy)
hold on;
semilogx(alpha_vals, validation_accuracy)
xlabel('Initial learning rate \alpha');
ylabel('Accuracy');
legend('Training','Validation','Location','best');
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement