  1. """
  2. PyTorch implementation of a sequence labeler (POS taggger).
  3.  
  4. Basic architecture:
  5. - take words
  6. - run though bidirectional GRU
  7. - predict labels one word at a time (left to right), using a recurrent neural network "decoder"
  8.  
  9. The decoder updates hidden state based on:
  10. - most recent word
  11. - the previous action (aka predicted label).
  12. - the previous hidden state
  13.  
  14. Can it be faster?!?!?!?!?!?
  15. """

from __future__ import division, print_function
import random
import pickle
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import torch.nn.functional as F

def reseed(seed=90210):
    random.seed(seed)
    torch.manual_seed(seed)

reseed()

class Example(object):
    def __init__(self, tokens, labels, n_labels):
        self.tokens = tokens
        self.labels = labels
        self.n_labels = n_labels

def minibatch(data, minibatch_size, reshuffle):
    if reshuffle:
        random.shuffle(data)
    for n in range(0, len(data), minibatch_size):
        yield data[n:n+minibatch_size]

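# For example, minibatch(list(range(7)), 3, False) yields [0, 1, 2], [3, 4, 5], [6];
# with reshuffle=True the data list is shuffled in place before batching.
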
def test_wsj():
    print()
    print('# test on wsj subset')

    data, n_types, n_labels = pickle.load(open('wsj.pkl', 'rb'))

    d_emb = 50      # word embedding dimension
    d_rnn = 51      # GRU hidden size (per direction)
    d_hid = 52      # decoder hidden size
    d_actemb = 5    # action (label) embedding dimension

    minibatch_size = 5
    n_epochs = 10
    preprocess_minibatch = True

    embed_word = nn.Embedding(n_types, d_emb)
    gru = nn.GRU(d_emb, d_rnn, bidirectional=True)
    embed_action = nn.Embedding(n_labels, d_actemb)
    # combines [previous action embedding ; previous decoder state ; encoder output
    # at the current word] into a new decoder state
    combine_arh = nn.Linear(d_actemb + d_rnn * 2 + d_hid, d_hid)

    # learned initial decoder hidden state
    initial_h_tensor = torch.Tensor(1, d_hid)
    initial_h_tensor.zero_()
    initial_h = Parameter(initial_h_tensor)

    # learned "embedding" of the (nonexistent) action before the first word
    initial_actemb_tensor = torch.Tensor(1, d_actemb)
    initial_actemb_tensor.zero_()
    initial_actemb = Parameter(initial_actemb_tensor)

    policy = nn.Linear(d_hid, n_labels)

    loss_fn = torch.nn.MSELoss(size_average=False)

    optimizer = torch.optim.Adam(
        list(embed_word.parameters()) +
        list(gru.parameters()) +
        list(embed_action.parameters()) +
        list(combine_arh.parameters()) +
        list(policy.parameters()) +
        [initial_h, initial_actemb]
        , lr=0.01)

    for _ in range(n_epochs):
        total_loss = 0
        for batch in minibatch(data, minibatch_size, True):
            optimizer.zero_grad()
            loss = 0

            if preprocess_minibatch:
                # for efficiency, combine RNN outputs on entire
                # minibatch in one go (requires padding with zeros,
                # should be masked but isn't right now)
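                # A possible fix (not applied here) would be to pack the padded batch
                # with torch.nn.utils.rnn.pack_padded_sequence so the GRU skips the
                # padding. Rough sketch, assuming the batch were sorted by decreasing length:
                #     lengths = [len(ex.tokens) for ex in batch]
                #     packed = nn.utils.rnn.pack_padded_sequence(all_e.transpose(0, 1), lengths)
                #     packed_out, _ = gru(packed)
                #     all_rnn_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out)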
                all_tokens = [ex.tokens for ex in batch]
                max_length = max(map(len, all_tokens))
                all_tokens = [tok + [0] * (max_length - len(tok)) for tok in all_tokens]
                all_e = embed_word(Variable(torch.LongTensor(all_tokens), requires_grad=False))
                # nn.GRU expects (seq_len, batch, features) by default, so transpose
                # the (batch, seq_len, emb) embeddings before running the encoder
                [all_rnn_out, _] = gru(all_e.transpose(0, 1))

            for i, ex in enumerate(batch):
                N = len(ex.tokens)
                if preprocess_minibatch:
                    # select this example's column of the batched encoder output
                    rnn_out = all_rnn_out[:, i, :].view(-1, 1, 2 * d_rnn)
                else:
                    e = embed_word(Variable(torch.LongTensor(ex.tokens), requires_grad=False)).view(N, 1, -1)
                    [rnn_out, _] = gru(e)
                prev_h = initial_h        # previous hidden state
                actemb = initial_actemb   # embedding of previous action
                output = []
                for t in range(N):
                    # update hidden state based on most recent
                    # *predicted* action (not ground truth)
                    inputs = [actemb, prev_h, rnn_out[t]]
                    h = F.relu(combine_arh(torch.cat(inputs, 1)))

                    # make prediction
                    pred_vec = policy(h)
                    pred_vec = pred_vec.view(-1)
                    pred = pred_vec.data.numpy().argmin()
                    output.append(pred)

                    # accumulate loss (squared error against costs)
                    truth = torch.ones(n_labels)
                    truth[ex.labels[t]] = 0
                    loss += loss_fn(pred_vec, Variable(truth, requires_grad=False))

                    # cache hidden state, previous action embedding
                    prev_h = h
                    actemb = embed_action(Variable(torch.LongTensor([pred.item()]), requires_grad=False))

                # print('output=%s, truth=%s' % (output, ex.labels))

            loss.backward()
            total_loss += loss.data.numpy()[0]
            optimizer.step()
        print(total_loss)


if __name__ == '__main__':
    test_wsj()
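
# Note on the data file: test_wsj() assumes 'wsj.pkl' unpickles to a tuple
# (data, n_types, n_labels), where data is a list of Example objects whose
# .tokens and .labels are equal-length lists of integer ids (word ids in
# [0, n_types), label ids in [0, n_labels)). This is inferred from how the
# script uses the data, not from the file itself.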