SHARE
TWEET

Untitled

a guest Jun 15th, 2019 63 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import torch
  2. from torch import LongTensor
  3. from torch.nn import Embedding, LSTM
  4. from torch.autograd import Variable
  5. from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
  6. import torch.nn.utils.rnn as rnn_utils
  7.  
  8. #Define data to be batched
  9. a = torch.tensor([[6,1], [9,2], [8,4], [4,14], [1,13], [11,3], [12,1],
  10. [10,10]])
  11. b = torch.tensor([[7,2],[3,13],[2,3],[5,6],[13,8],[7,10]])
  12. c = torch.tensor([[12,2],[5,3],[8,7],[14,10]])
  13.  
  14. # put into tensor and then variable
  15. vectorized_seqs=[a,b,c]
  16. Variable(LongTensor(vectorized_seqs[0]))
  17.  
  18. # Define the LSTM model with embedding layer
  19. embed = Embedding(15, 4) # embedding_dim = 4, with vocab length 15
  20. lstm = LSTM(input_size=4, hidden_size=5, batch_first=True) # input_dim = 4, hidden_dim = 5
  21.  
  22. # Get sequence lengths
  23. seq_lengths = LongTensor(list(map(len, vectorized_seqs)))
  24.  
  25. # Get padded sequence tensor
  26. seq_tensor = rnn_utils.pad_sequence(vectorized_seqs, batch_first=True)
  27.  
  28. # Input to embedding layer
  29. embedded_seq_tensor = embed(seq_tensor)
  30. # Size is now: torch.Size([3, 8, 2, 4])
  31.  
  32. # Pack the input
  33. packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu().numpy(), batch_first=True)
  34. # Size is now: torch.Size([18, 2, 4])
  35.  
  36. # input to LSTM
  37. packed_output, (ht, ct) = lstm(packed_input)
Traceback (most recent call last):
  File "/Users/mengwei/Documents/GitHub/packing-unpacking-pytorch-minimal-tutorial/pad_packed_demo.py", line 163, in <module>
    packed_output, (ht, ct) = lstm(packed_input)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 175, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 131, in check_forward_args
    expected_input_dim, input.dim()))
RuntimeError: input must have 2 dimensions, got 3
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top