Rexnime

findLayersToReplace function

Jun 1st, 2021
85
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
MatLab 1.72 KB | None | 0 0
  % findLayersToReplace(lgraph) finds the single classification layer and the
  % preceding learnable (fully connected or convolutional) layer of the layer
  % graph lgraph.
  %
  %   [learnableLayer, classLayer] = findLayersToReplace(lgraph)
  %
  %   Inputs:
  %     lgraph         - nnet.cnn.LayerGraph object to search.
  %
  %   Outputs:
  %     learnableLayer - the first layer whose class is exactly
  %                      nnet.cnn.layer.FullyConnectedLayer or
  %                      nnet.cnn.layer.Convolution2DLayer, reached by walking
  %                      backwards from the classification layer.
  %     classLayer     - the only layer of lgraph for which
  %                      isa(layer,'nnet.cnn.layer.ClassificationOutputLayer')
  %                      or isa(layer,'nnet.layer.ClassificationLayer') is true.
  %
  %   Errors if lgraph is not a LayerGraph, if it does not contain exactly one
  %   classification layer, or if the backward walk reaches a layer that does
  %   not have exactly one predecessor before a learnable layer is found.

  function [learnableLayer,classLayer] = findLayersToReplace(lgraph)

  if ~isa(lgraph,'nnet.cnn.LayerGraph')
      error('Argument must be a LayerGraph object.')
  end

  % Connection endpoints and layer names as string arrays so they can be
  % compared with == and ismember.
  src = string(lgraph.Connections.Source);
  dst = string(lgraph.Connections.Destination);
  layerNames = string({lgraph.Layers.Name}');

  % Find the classification layer. The layer graph must have a single
  % classification layer.
  isClassificationLayer = arrayfun(@(l) ...
      isa(l,'nnet.cnn.layer.ClassificationOutputLayer') || isa(l,'nnet.layer.ClassificationLayer'), ...
      lgraph.Layers);

  if sum(isClassificationLayer) ~= 1
      error('Layer graph must have a single classification layer.')
  end
  classLayer = lgraph.Layers(isClassificationLayer);

  % Class names that count as "learnable". This must be a cell array:
  % concatenating the names with [] would yield one char row, and ismember
  % would then compare single characters instead of whole class names.
  learnableTypes = {'nnet.cnn.layer.FullyConnectedLayer', ...
                    'nnet.cnn.layer.Convolution2DLayer'};

  % Traverse the layer graph in reverse starting from the classification
  % layer. If the walk does not find exactly one predecessor, throw an error.
  currentLayerIdx = find(isClassificationLayer);
  while true

      % Zero predecessors (e.g. the walk went past an input layer) or several
      % predecessors both mean there is no unique learnable layer to return.
      % NOTE(review): presumably multi-input layers also land here, because
      % their Destination entries carry a port suffix and never equal the
      % bare layer name - confirm.
      if numel(currentLayerIdx) ~= 1
          error('Layer graph must have a single learnable layer preceding the classification layer.')
      end

      currentLayer = lgraph.Layers(currentLayerIdx);
      if ismember(class(currentLayer), learnableTypes)
          learnableLayer = currentLayer;
          return
      end

      % Step back: collect the sources of all connections ending at the
      % current layer, then locate those layers by name. ismember avoids the
      % size mismatch that an element-wise == would hit with several sources.
      predecessorNames = src(dst == layerNames(currentLayerIdx));
      currentLayerIdx = find(ismember(layerNames, predecessorNames));

  end

  end
Advertisement
Add Comment
Please, Sign In to add comment