  1. """
  2. PyTorch implementation of a sequence labeler (POS taggger).
  3.  
  4. Basic architecture:
  5. - take words
  6. - run though bidirectional GRU
  7. - predict labels one word at a time (left to right), using a recurrent neural network "decoder"
  8.  
  9. The decoder updates hidden state based on:
  10. - most recent word
  11. - the previous action (aka predicted label).
  12. - the previous hidden state
  13.  
  14. Can it be faster?!?!?!?!?!?
  15. """

from __future__ import division
import random
import pickle
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import torch.nn.functional as F

def reseed(seed=90210):
    random.seed(seed)
    torch.manual_seed(seed)

reseed()

class Example(object):
    # one sentence: a list of token ids, a same-length list of gold label ids,
    # and the size of the label set
    def __init__(self, tokens, labels, n_labels):
        self.tokens = tokens
        self.labels = labels
        self.n_labels = n_labels

def minibatch(data, minibatch_size, reshuffle):
    # yield the data in chunks of minibatch_size, optionally shuffling first
    if reshuffle:
        random.shuffle(data)
    for n in range(0, len(data), minibatch_size):
        yield data[n:n+minibatch_size]

def test_wsj():
    # print('# test on wsj subset')

    data, n_types, n_labels = pickle.load(open('wsj.pkl', 'rb'))

    d_emb = 50
    d_rnn = 51
    d_hid = 52
    d_actemb = 5

    minibatch_size = 5
    n_epochs = 10
    preprocess_minibatch = True

    embed_word = nn.Embedding(n_types, d_emb)
    gru = nn.GRU(d_emb, d_rnn, bidirectional=True)
    embed_action = nn.Embedding(n_labels, d_actemb)
    # combines [action embedding; bidirectional GRU output; previous hidden] -> new hidden
    combine_arh = nn.Linear(d_actemb + d_rnn * 2 + d_hid, d_hid)

    # learned initial decoder hidden state
    initial_h_tensor = torch.Tensor(1, d_hid)
    initial_h_tensor.zero_()
    initial_h = Parameter(initial_h_tensor)

    # learned stand-in for the "previous action" embedding at the first time step
    initial_actemb_tensor = torch.Tensor(1, d_actemb)
    initial_actemb_tensor.zero_()
    initial_actemb = Parameter(initial_actemb_tensor)

    policy = nn.Linear(d_hid, n_labels)

    loss_fn = torch.nn.MSELoss(size_average=False)

    optimizer = torch.optim.Adam(
        list(embed_word.parameters()) +
        list(gru.parameters()) +
        list(embed_action.parameters()) +
        list(combine_arh.parameters()) +
        list(policy.parameters()) +
        [initial_h, initial_actemb],
        lr=0.01)

    for _ in range(n_epochs):
        total_loss = 0
        for batch in minibatch(data, minibatch_size, True):
            optimizer.zero_grad()
            loss = 0

            if preprocess_minibatch:
                # for efficiency, combine RNN outputs on entire
                # minibatch in one go (requires padding with zeros,
                # should be masked but isn't right now)
                all_tokens = [ex.tokens for ex in batch]
                max_length = max(map(len, all_tokens))
                all_tokens = [tok + [0] * (max_length - len(tok)) for tok in all_tokens]
                all_e = embed_word(Variable(torch.LongTensor(all_tokens), requires_grad=False))
                # nn.GRU expects (seq_len, batch, features) by default, so put time first
                [all_rnn_out, _] = gru(all_e.transpose(0, 1).contiguous())

            for b, ex in enumerate(batch):
                N = len(ex.tokens)
                if preprocess_minibatch:
                    # pick out this example's slice of the batched GRU output
                    rnn_out = all_rnn_out[:, b, :].view(-1, 1, 2 * d_rnn)
                else:
                    e = embed_word(Variable(torch.LongTensor(ex.tokens), requires_grad=False)).view(N, 1, -1)
                    [rnn_out, _] = gru(e)
                prev_h = initial_h  # previous hidden state
                actemb = initial_actemb  # embedding of previous action
                output = []
                for t in range(N):
                    # update hidden state based on most recent
                    # *predicted* action (not ground truth)
                    inputs = [actemb, prev_h, rnn_out[t]]
                    h = F.relu(combine_arh(torch.cat(inputs, 1)))

                    # make prediction: the policy outputs per-label costs, so take the argmin
                    pred_vec = policy(h)
                    pred = pred_vec.data.numpy().argmin()
                    output.append(pred)

                    # accumulate loss (squared error against costs: 0 for the gold label, 1 otherwise)
                    truth = torch.ones(n_labels)
                    truth[ex.labels[t]] = 0
                    loss += loss_fn(pred_vec, Variable(truth, requires_grad=False))

                    # cache hidden state, previous action embedding
                    prev_h = h
                    actemb = embed_action(Variable(torch.LongTensor([pred.item()]), requires_grad=False))

                # print('output=%s, truth=%s' % (output, ex.labels))

            loss.backward()
            total_loss += loss.data.numpy()[0]
            optimizer.step()
        # print(total_loss)

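def make_toy_wsj(path='wsj.pkl', n_sentences=200, n_types=100, n_labels=10):
    # A minimal sketch, not part of the WSJ preprocessing (which this paste does not
    # include): it writes a small synthetic dataset in the format test_wsj() expects,
    # i.e. a pickled (list of Example, n_types, n_labels) tuple, so the script can be
    # smoke-tested without the real data. The name and all sizes here are made up;
    # call it once before running test_wsj().
    data = []
    for _ in range(n_sentences):
        length = random.randint(3, 15)
        tokens = [random.randrange(n_types) for _ in range(length)]
        labels = [random.randrange(n_labels) for _ in range(length)]
        data.append(Example(tokens, labels, n_labels))
    with open(path, 'wb') as f:
        pickle.dump((data, n_types, n_labels), f)
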
if __name__ == '__main__':
    test_wsj()