Advertisement
Guest User

Untitled

a guest
Sep 21st, 2017
73
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
text 0.80 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. from torch.autograd import Variable
  4. from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
  5.  
  6. x = Variable(torch.randn(4, 20, 5))
  7. x_len = torch.IntTensor([4, 2, 1, 6])
  8. sorted_x_len, indx = torch.sort(x_len, 0, descending=True)
  9. x = pack_padded_sequence(x[indx], sorted_x_len.tolist(), batch_first=True)
  10.  
  11. lstm = nn.LSTM(5, 5, batch_first=True)
  12. h0 = Variable(torch.zeros(1, 4, 5))
  13. c0 = Variable(torch.zeros(1, 4, 5))
  14.  
  15. packed_h, (packed_h_t, packed_c_t) = lstm(x, (h0, c0))
  16. hh, _ = pad_packed_sequence(packed_h, batch_first=True)
  17. print(hh.size()) # Size 4 x 6 x 5 instead of 4 x 20 x 5
  18.  
  19. # restore the sorting
  20. _, inverse_indx = torch.sort(indx, 0)
  21. restore_packed_h_t = packed_h_t[:, inverse_indx]
  22. restore_hh = hh[inverse_indx]
  23. print(restore_packed_h_t)
  24. print(restore_hh)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement