Advertisement
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
- import torch
- import torch.nn as nn
- from torch.autograd import Variable
- from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
- x = Variable(torch.randn(4, 20, 5))
- x_len = torch.IntTensor([4, 2, 1, 6])
- sorted_x_len, indx = torch.sort(x_len, 0, descending=True)
- x = pack_padded_sequence(x[indx], sorted_x_len.tolist(), batch_first=True)
- lstm = nn.LSTM(5, 5, batch_first=True)
- h0 = Variable(torch.zeros(1, 4, 5))
- c0 = Variable(torch.zeros(1, 4, 5))
- packed_h, (packed_h_t, packed_c_t) = lstm(x, (h0, c0))
- hh, _ = pad_packed_sequence(packed_h, batch_first=True)
- print(hh.size()) # Size 4 x 6 x 5 instead of 4 x 20 x 5
- # restore the sorting
- _, inverse_indx = torch.sort(indx, 0)
- restore_packed_h_t = packed_h_t[:, inverse_indx]
- restore_hh = hh[inverse_indx]
- print(restore_packed_h_t)
- print(restore_hh)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement