Advertisement
Guest User

Untitled

a guest
Nov 11th, 2019
161
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
text 1.04 KB | None | 0 0
  function test_example_DBN
  %TEST_EXAMPLE_DBN Smoke test for the DBN training path on MNIST.
  %   Example 1 trains a single 100-hidden-unit RBM on the training images
  %   and plots its weights.
  %   Example 2 trains a 100-100 DBN, unfolds it into a neural network with
  %   10 outputs, fine-tunes that network with nntrain, evaluates it with
  %   nntest and asserts that the returned error is below 0.10.
  %
  %   Requires mnist_uint8 (providing train_x, train_y, test_x, test_y) to be
  %   loadable, plus dbnsetup, dbntrain, dbnunfoldtonn, nntrain, nntest and
  %   visualize on the path.

      load mnist_uint8;

      % Convert to double. The images are divided by 255, presumably mapping
      % uint8 pixel values into [0,1] (NOTE(review): confirm stored type).
      train_x = double(train_x) / 255;
      test_x  = double(test_x)  / 255;
      train_y = double(train_y);
      test_y  = double(test_y);

      %% ex1 train a 100 hidden unit RBM and visualize its weights
      % Seed the legacy generator so the run is repeatable.
      rand('state', 0)
      dbn = struct();
      dbn.sizes = 100;               % a single hidden layer of 100 units
      opts = struct();
      opts.numepochs = 1;
      opts.batchsize = 100;
      opts.momentum  = 0;
      opts.alpha     = 1;
      dbn = dbnsetup(dbn, train_x, opts);
      dbn = dbntrain(dbn, train_x, opts);
      figure; visualize(dbn.rbm{1}.W'); % Visualize the RBM weights

      %% ex2 train a 100-100 hidden unit DBN and use its weights to initialize a NN
      rand('state', 0)
      % Start from fresh structs so nothing produced in ex1 (e.g. the trained
      % rbm cell stored in dbn) carries over into this experiment.
      dbn = struct();
      dbn.sizes = [100 100];         % two stacked hidden layers of 100 units
      opts = struct();
      opts.numepochs = 1;
      opts.batchsize = 100;
      opts.momentum  = 0;
      opts.alpha     = 1;
      dbn = dbnsetup(dbn, train_x, opts);
      dbn = dbntrain(dbn, train_x, opts);

      % Unfold the DBN into a NN with 10 output units (presumably one per
      % digit class).
      nn = dbnunfoldtonn(dbn, 10);
      nn.activation_function = 'sigm';

      % Fine-tune the NN. The same opts struct is passed on, with the
      % epoch/batch settings stated explicitly for this stage.
      opts.numepochs = 1;
      opts.batchsize = 100;
      nn = nntrain(nn, train_x, train_y, opts);
      % Only the error value is checked. The second output is requested but
      % discarded, which keeps the call's output count the same as before.
      [er, ~] = nntest(nn, test_x, test_y);

      assert(er < 0.10, 'Too big error');
  end
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement