Advertisement
niharsarangi

RBMGB

Mar 6th, 2013
1,255
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
MatLab 4.96 KB | None | 0 0
  function [model, errors] = rbmGB(X, numhid, varargin)
  %RBMGB Learn an RBM with Bernoulli hidden units and Gaussian visible units.
  %Based on the implementation of Kevin Swersky and Ruslan Salakhutdinov.
  %
  %The visible units are sampled with unit-variance Gaussian noise (see the
  %"go down" step below), so X should presumably be standardized to zero
  %mean and unit variance per dimension before training.
  %
  %INPUTS:
  %X              ... N-by-d data matrix, one instance per row. Continuous.
  %numhid         ... number of hidden units
  %additional inputs (specified as name/value pairs or in a struct):
  %method         ... 'CD' (contrastive divergence, default) or 'SML'
  %                   (stochastic maximum likelihood, i.e. a persistent
  %                   negative chain); matched case-insensitively
  %eta            ... learning rate (default 0.1)
  %momentum       ... momentum applied to the parameter updates, for
  %                   smoother learning (default 0.5)
  %               ... NOTE: momentum is not recommended with SML
  %maxepoch       ... # of epochs; each is a full pass through the data
  %                   (default 50)
  %avglast        ... the returned parameters are a running average over
  %                   all updates made in the last avglast epochs
  %                   (default 5). Procedure suggested for faster
  %                   convergence by Kevin Swersky in his MSc thesis
  %penalty        ... weight decay factor (default 2e-4)
  %batchsize      ... number of training instances per batch (default 100)
  %verbose        ... if true, print progress (default false)
  %anneal         ... if true, the penalty is annealed linearly through the
  %                   epochs down to 10% of its original value (default false)
  %
  %OUTPUTS:
  %model.type     ... type of RBM (its visible and hidden units): 'GB'
  %model.W        ... d-by-numhid weights of the connections
  %model.b        ... 1-by-numhid biases of the hidden layer
  %model.c        ... 1-by-d biases of the visible layer
  %model.top      ... hidden-unit probabilities for X, to be used as the
  %                   input when training DBNs
  %errors         ... 1-by-maxepoch summed squared error between each batch
  %                   and its negative-phase visible means, per epoch. Under
  %                   SML the negative phase comes from the persistent chain,
  %                   so this is not a true reconstruction error there.
  
  %Process options.
  %varargin is always a cell. When a caller forwards its own options they
  %arrive wrapped as a single cell (or as a single struct), so unwrap that
  %one element before handing it to prepareArgs.
  if (numel(varargin) == 1 && (iscell(varargin{1}) || isstruct(varargin{1})))
      args = prepareArgs(varargin{1});
  else
      args = prepareArgs(varargin);
  end
  [   method        ...
      eta           ...
      momentum      ...
      maxepoch      ...
      avglast       ...
      penalty       ...
      batchsize     ...
      verbose       ...
      anneal        ...
      ] = process_options(args    , ...
      'method'        ,  'CD'     , ...
      'eta'           ,  0.1      , ...
      'momentum'      ,  0.5      , ...
      'maxepoch'      ,  50       , ...
      'avglast'       ,  5        , ...
      'penalty'       , 2e-4      , ...
      'batchsize'     , 100       , ...
      'verbose'       , false     , ...
      'anneal'        , false);
  
  %Reject unknown methods up front; otherwise the negative chain would be
  %left uninitialized and training would silently do the wrong thing.
  if (~ischar(method) || ~any(strcmpi(method, {'CD', 'SML'})))
      error('rbmGB:badMethod', 'method must be ''CD'' or ''SML''.');
  end
  useSML = strcmpi(method, 'SML');
  
  avgstart = maxepoch - avglast;
  oldpenalty = penalty;
  [N, numdims] = size(X);
  
  if (verbose)
      fprintf('Preprocessing data...\n');
  end
  
  %Create batches: deal the instances round-robin into numbatches groups,
  %then shuffle the assignment. Batch sizes therefore differ by at most one.
  numbatches = ceil(N/batchsize);
  groups = repmat(1:numbatches, 1, batchsize);
  groups = groups(randperm(N));
  batchdata = cell(1, numbatches);
  for i = 1:numbatches
      batchdata{i} = X(groups == i, :);
  end
  
  %train RBM
  W = 0.1*randn(numdims, numhid);
  c = zeros(1, numdims);
  b = zeros(1, numhid);
  Winc = zeros(numdims, numhid);
  binc = zeros(1, numhid);
  cinc = zeros(1, numdims);
  Wavg = W;
  bavg = b;
  cavg = c;
  t = 1;
  errors = zeros(1, maxepoch);
  nhstates = [];  % negative-phase hidden states (the persistent chain under SML)
  
  for epoch = 1:maxepoch
  
      errsum = 0;
      if (anneal)
          %apply linear weight penalty decay
          penalty = oldpenalty - 0.9*epoch/maxepoch*oldpenalty;
      end
  
      for batch = 1:numbatches
          data = batchdata{batch};
          numcases = size(data, 1);
  
          %go up
          ph = logistic(data*W + repmat(b, numcases, 1));
          phstates = ph > rand(numcases, numhid);
          if (~useSML || isempty(nhstates))
              %CD restarts the negative chain from the data at every batch.
              %SML seeds its persistent chain once, from the first batch.
              nhstates = phstates;
          else
              %Batches can differ in size by one instance, so resize the
              %persistent chain to this batch: drop surplus particles, or
              %top it up with hidden states sampled from this batch.
              nchain = size(nhstates, 1);
              if (nchain > numcases)
                  nhstates = nhstates(1:numcases, :);
              elseif (nchain < numcases)
                  nhstates = [nhstates; phstates(nchain+1:numcases, :)];
              end
          end
  
          %go down: Gaussian visible means plus unit-variance noise
          negdata = nhstates*W' + repmat(c, numcases, 1);
          negdatastates = negdata + randn(numcases, numdims);
  
          %go up one more time
          nh = logistic(negdatastates*W + repmat(b, numcases, 1));
          nhstates = nh > rand(numcases, numhid);
  
          %update weights and biases. Sum explicitly along dimension 1 so a
          %single-instance batch still yields row-vector bias gradients
          %(plain sum() of a 1-by-k row would collapse it to a scalar).
          dW = data'*ph - negdatastates'*nh;
          dc = sum(data, 1) - sum(negdatastates, 1);
          db = sum(ph, 1) - sum(nh, 1);
          Winc = momentum*Winc + eta*(dW/numcases - penalty*W);
          binc = momentum*binc + eta*(db/numcases);
          cinc = momentum*cinc + eta*(dc/numcases);
          W = W + Winc;
          b = b + binc;
          c = c + cinc;
  
          if (epoch > avgstart)
              %apply averaging: incremental running mean over the updates
              %made since averaging started
              Wavg = Wavg - (1/t)*(Wavg - W);
              cavg = cavg - (1/t)*(cavg - c);
              bavg = bavg - (1/t)*(bavg - b);
              t = t+1;
          else
              Wavg = W;
              bavg = b;
              cavg = c;
          end
  
          %accumulate reconstruction error
          err = sum(sum((data - negdata).^2));
          errsum = err + errsum;
      end
  
      errors(epoch) = errsum;
      if (verbose)
          fprintf('Ended epoch %i/%i. Reconstruction error is %f\n', ...
              epoch, maxepoch, errsum);
      end
  end
  
  model.type = 'GB';
  model.top = logistic(X*Wavg + repmat(bavg, N, 1));
  model.W = Wavg;
  model.b = bavg;
  model.c = cavg;
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement