Advertisement
Guest User

Untitled

a guest
May 4th, 2018
102
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Octave 8.70 KB | None | 0 0
  % This function trains a neural network language model.
function [model] = train(epochs)
% Inputs:
%   epochs: Number of epochs to run.
% Output:
%   model: A struct containing the learned weights and biases and vocabulary.
%
% Trains with mini-batch gradient descent plus momentum. While training it
% prints a running training cross-entropy (CE) and a periodic validation CE.
% Afterwards it prints the final training, validation and test CE and the
% elapsed wall-clock time.
%
% Requires load_data(batchsize) and fprop(...) to be on the path.

% Documented Octave-detection idiom; evaluates to false in MATLAB.
OctaveMode = exist('OCTAVE_VERSION', 'builtin') ~= 0;
if OctaveMode
  % Promote the automatic-broadcasting warning to an error so that shape
  % mistakes fail loudly instead of silently broadcasting.
  warning('error', 'Octave:broadcast');
end
% tic/toc behave the same in Octave and MATLAB (replaces time vs clock/etime).
timer_id = tic;

% SET HYPERPARAMETERS HERE.
batchsize = 100;  % Mini-batch size.
learning_rate = 0.1;  % Learning rate; default = 0.1.
momentum = 0.9;  % Momentum; default = 0.9.
numhid1 = 50;  % Dimensionality of embedding space; default = 50.
numhid2 = 200;  % Number of units in hidden layer; default = 200.
init_wt = 0.01;  % Standard deviation of the normal distribution
                 % which is sampled to get the initial weights; default = 0.01

% VARIABLES FOR TRACKING TRAINING PROGRESS.
show_training_CE_after = 100;
show_validation_CE_after = 1000;

% LOAD DATA.
[train_input, train_target, valid_input, valid_target, ...
  test_input, test_target, vocab] = load_data(batchsize);
% From here on, batchsize is whatever mini-batch size load_data produced.
[numwords, batchsize, numbatches] = size(train_input);
vocab_size = size(vocab, 2);

% INITIALIZE WEIGHTS AND BIASES.
word_embedding_weights = init_wt * randn(vocab_size, numhid1);
embed_to_hid_weights = init_wt * randn(numwords * numhid1, numhid2);
hid_to_output_weights = init_wt * randn(numhid2, vocab_size);
hid_bias = zeros(numhid2, 1);
output_bias = zeros(vocab_size, 1);

% Momentum "velocities", one per parameter, with the parameter's shape.
word_embedding_weights_delta = zeros(vocab_size, numhid1);
word_embedding_weights_gradient = zeros(vocab_size, numhid1);
embed_to_hid_weights_delta = zeros(numwords * numhid1, numhid2);
hid_to_output_weights_delta = zeros(numhid2, vocab_size);
hid_bias_delta = zeros(numhid2, 1);
output_bias_delta = zeros(vocab_size, 1);
% Column k of the identity is the 1-of-K encoding of word index k.
expansion_matrix = eye(vocab_size);
tiny = exp(-30);  % Keeps log() finite when a predicted probability is 0.

% TRAIN.
trainset_CE = NaN;  % Reported unchanged if epochs < 1 (was undefined).
for epoch = 1:epochs
  fprintf(1, 'Epoch %d\n', epoch);
  % Reset the chunk mean AND its sample count together at every epoch.
  % Previously only the mean was reset. When numbatches is not a multiple
  % of show_training_CE_after, the first chunk of every later epoch was
  % then averaged with a stale count and reported too low.
  chunk_count = 0;
  this_chunk_CE = 0;
  trainset_CE = 0;
  % LOOP OVER MINI-BATCHES.
  for m = 1:numbatches
    input_batch = train_input(:, :, m);
    target_batch = train_target(:, :, m);

    % FORWARD PROPAGATE.
    % Compute the state of each layer in the network given the input batch
    % and all weights and biases.
    [embedding_layer_state, hidden_layer_state, output_layer_state] = ...
      fprop(input_batch, ...
            word_embedding_weights, embed_to_hid_weights, ...
            hid_to_output_weights, hid_bias, output_bias);

    % COMPUTE DERIVATIVE.
    %% Expand the target to a sparse 1-of-K vector.
    expanded_target_batch = expansion_matrix(:, target_batch);
    %% Compute derivative of cross-entropy loss function.
    error_deriv = output_layer_state - expanded_target_batch;

    % MEASURE LOSS FUNCTION.
    CE = -sum(sum(...
      expanded_target_batch .* log(output_layer_state + tiny))) / batchsize;
    % Incremental means: this_chunk_CE over the current display chunk,
    % trainset_CE over the epoch so far.
    chunk_count = chunk_count + 1;
    this_chunk_CE = this_chunk_CE + (CE - this_chunk_CE) / chunk_count;
    trainset_CE = trainset_CE + (CE - trainset_CE) / m;
    fprintf(1, '\rBatch %d Train CE %.3f', m, this_chunk_CE);
    if mod(m, show_training_CE_after) == 0
      fprintf(1, '\n');
      chunk_count = 0;
      this_chunk_CE = 0;
    end
    flush_output(OctaveMode);

    % BACK PROPAGATE.
    %% OUTPUT LAYER.
    hid_to_output_weights_gradient = hidden_layer_state * error_deriv';
    output_bias_gradient = sum(error_deriv, 2);
    % h .* (1 - h) is the logistic derivative. This assumes fprop uses
    % logistic hidden units -- confirm against fprop.
    back_propagated_deriv_1 = (hid_to_output_weights * error_deriv) ...
      .* hidden_layer_state .* (1 - hidden_layer_state);

    %% HIDDEN LAYER.
    % These three were zero placeholders. That left embed_to_hid_weights,
    % hid_bias and word_embedding_weights untouched by training.
    % back_propagated_deriv_1 is numhid2 x batchsize. embedding_layer_state
    % is presumably (numwords*numhid1) x batchsize, as set by fprop.
    embed_to_hid_weights_gradient = ...
      embedding_layer_state * back_propagated_deriv_1';
    % Same reduction over the batch as output_bias_gradient.
    hid_bias_gradient = sum(back_propagated_deriv_1, 2);
    % Gives (numwords*numhid1) x batchsize, the shape the loop below slices.
    back_propagated_deriv_2 = embed_to_hid_weights * back_propagated_deriv_1;

    %% EMBEDDING LAYER.
    % Word position w owns rows (w-1)*numhid1+1 .. w*numhid1 of
    % back_propagated_deriv_2. Scatter-add that slice onto the embedding
    % rows of the words that appeared at position w.
    word_embedding_weights_gradient(:) = 0;
    for w = 1:numwords
      pos_rows = (w - 1) * numhid1 + (1:numhid1);
      word_embedding_weights_gradient = word_embedding_weights_gradient + ...
        expansion_matrix(:, input_batch(w, :)) * ...
        back_propagated_deriv_2(pos_rows, :)';
    end

    % UPDATE WEIGHTS AND BIASES.
    [word_embedding_weights, word_embedding_weights_delta] = momentum_step( ...
      word_embedding_weights, word_embedding_weights_delta, ...
      word_embedding_weights_gradient, momentum, learning_rate, batchsize);
    [embed_to_hid_weights, embed_to_hid_weights_delta] = momentum_step( ...
      embed_to_hid_weights, embed_to_hid_weights_delta, ...
      embed_to_hid_weights_gradient, momentum, learning_rate, batchsize);
    [hid_to_output_weights, hid_to_output_weights_delta] = momentum_step( ...
      hid_to_output_weights, hid_to_output_weights_delta, ...
      hid_to_output_weights_gradient, momentum, learning_rate, batchsize);
    [hid_bias, hid_bias_delta] = momentum_step( ...
      hid_bias, hid_bias_delta, hid_bias_gradient, ...
      momentum, learning_rate, batchsize);
    [output_bias, output_bias_delta] = momentum_step( ...
      output_bias, output_bias_delta, output_bias_gradient, ...
      momentum, learning_rate, batchsize);

    % VALIDATE.
    if mod(m, show_validation_CE_after) == 0
      fprintf(1, '\rRunning validation ...');
      flush_output(OctaveMode);
      CE = dataset_CE(valid_input, valid_target, expansion_matrix, tiny, ...
                      word_embedding_weights, embed_to_hid_weights, ...
                      hid_to_output_weights, hid_bias, output_bias);
      fprintf(1, ' Validation CE %.3f\n', CE);
      flush_output(OctaveMode);
    end
  end
  fprintf(1, '\rAverage Training CE %.3f\n', trainset_CE);
end
fprintf(1, 'Finished Training.\n');
flush_output(OctaveMode);
fprintf(1, 'Final Training CE %.3f\n', trainset_CE);

% EVALUATE ON VALIDATION SET.
fprintf(1, '\rRunning validation ...');
flush_output(OctaveMode);
CE = dataset_CE(valid_input, valid_target, expansion_matrix, tiny, ...
                word_embedding_weights, embed_to_hid_weights, ...
                hid_to_output_weights, hid_bias, output_bias);
fprintf(1, '\rFinal Validation CE %.3f\n', CE);
flush_output(OctaveMode);

% EVALUATE ON TEST SET.
fprintf(1, '\rRunning test ...');
flush_output(OctaveMode);
CE = dataset_CE(test_input, test_target, expansion_matrix, tiny, ...
                word_embedding_weights, embed_to_hid_weights, ...
                hid_to_output_weights, hid_bias, output_bias);
fprintf(1, '\rFinal Test CE %.3f\n', CE);
flush_output(OctaveMode);

model.word_embedding_weights = word_embedding_weights;
model.embed_to_hid_weights = embed_to_hid_weights;
model.hid_to_output_weights = hid_to_output_weights;
model.hid_bias = hid_bias;
model.output_bias = output_bias;
model.vocab = vocab;

fprintf(1, 'Training took %.2f seconds\n', toc(timer_id));
end

% Applies one momentum step. The velocity DELTA decays by MOMENTUM and
% accumulates the batch-averaged GRADIENT; PARAM then moves against it.
function [param, delta] = momentum_step(param, delta, gradient, momentum, ...
                                        learning_rate, batchsize)
  delta = momentum .* delta + gradient ./ batchsize;
  param = param - learning_rate * delta;
end

% Mean cross-entropy of the model's predictions over a whole dataset.
% Each column of INPUTS is one example. TARGETS holds word indices, which
% are expanded to 1-of-K columns via EXPANSION_MATRIX.
function CE = dataset_CE(inputs, targets, expansion_matrix, tiny, ...
                         word_embedding_weights, embed_to_hid_weights, ...
                         hid_to_output_weights, hid_bias, output_bias)
  [embedding_layer_state, hidden_layer_state, output_layer_state] = ...
    fprop(inputs, word_embedding_weights, embed_to_hid_weights, ...
          hid_to_output_weights, hid_bias, output_bias);
  expanded_targets = expansion_matrix(:, targets);
  CE = -sum(sum(expanded_targets .* log(output_layer_state + tiny))) ...
       / size(inputs, 2);
end

% Flushes stdout so progress output is shown promptly. fflush is only
% called when OctaveMode is set.
function flush_output(OctaveMode)
  if OctaveMode
    fflush(1);
  end
end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement