Guest User

Untitled

a guest
Jun 21st, 2018
88
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.60 KB | None | 0 0
  1. import numpy as np
  2. import torch
  3. from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
  4.  
  5. def pack_lstm(items, lstm):
  6. N = len(items)
  7. reorder_args = np.argsort([len(it) for it in items])[::-1]
  8. origin_args = torch.from_numpy(np.argsort(reorder_args))
  9. ordered = [items[i] for i in reorder_args]
  10. packed_items = pack_padded_sequence(pad_sequence(ordered, batch_first=True), [len(od) for od in ordered], batch_first=True)
  11. _, (hn, _) = lstm(packed_items)
  12. by_inst_repr = hn.transpose(0,1).reshape(N,-1)
  13. # Now untwizzle
  14. return torch.index_select(by_inst_repr, 0, origin_args)
Add Comment
Please, Sign In to add comment