• API
• FAQ
• Tools
• Archive
SHARE
TWEET

# Untitled

a guest Jun 15th, 2019 63 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
"""Minimal demo of padding, packing and unpacking variable-length sequences for an LSTM.

Each time step of every sequence holds TOKENS_PER_STEP token ids. Each id is
embedded separately and the per-step embeddings are concatenated, so the LSTM
receives one feature vector of size TOKENS_PER_STEP * EMBEDDING_DIM per step.
"""
import torch
from torch import LongTensor
from torch.nn import Embedding, LSTM
import torch.nn.utils.rnn as rnn_utils

VOCAB_SIZE = 15  # token ids in the data below are 1..14; 0 is used only as padding
EMBEDDING_DIM = 4
TOKENS_PER_STEP = 2  # every time step below is a pair of token ids
HIDDEN_DIM = 5

# Define data to be batched: each sequence has shape (seq_len, TOKENS_PER_STEP).
a = torch.tensor([[6, 1], [9, 2], [8, 4], [4, 14], [1, 13], [11, 3], [12, 1], [10, 10]])
b = torch.tensor([[7, 2], [3, 13], [2, 3], [5, 6], [13, 8], [7, 10]])
c = torch.tensor([[12, 2], [5, 3], [8, 7], [14, 10]])

# pack_padded_sequence (without enforce_sorted=False, which older PyTorch
# releases lack) needs the batch ordered by decreasing length. sorted() is
# stable, so the already-ordered a, b, c keep their order.
vectorized_seqs = sorted([a, b, c], key=len, reverse=True)

# Define the LSTM model with embedding layer. padding_idx=0 keeps the
# padding token's embedding at zero.
embed = Embedding(VOCAB_SIZE, EMBEDDING_DIM, padding_idx=0)
# The LSTM consumes the concatenated per-token embeddings of one step, so its
# input_size is TOKENS_PER_STEP * EMBEDDING_DIM (8), not EMBEDDING_DIM (4).
# Feeding the un-merged (3, 8, 2, 4) embedding produced 3-D packed data
# (18, 2, 4), which the LSTM rejects with
# "RuntimeError: input must have 2 dimensions, got 3".
lstm = LSTM(
    input_size=TOKENS_PER_STEP * EMBEDDING_DIM,
    hidden_size=HIDDEN_DIM,
    batch_first=True,
)

# Get sequence lengths: [8, 6, 4]
seq_lengths = LongTensor([len(seq) for seq in vectorized_seqs])

# Get padded sequence tensor: torch.Size([3, 8, 2]), shorter sequences padded with 0.
seq_tensor = rnn_utils.pad_sequence(vectorized_seqs, batch_first=True)

# Input to embedding layer. Size is now: torch.Size([3, 8, 2, 4])
embedded_seq_tensor = embed(seq_tensor)
# Merge the token axis into the feature axis so every step is one vector.
# Size is now: torch.Size([3, 8, 8])
embedded_seq_tensor = embedded_seq_tensor.reshape(
    embedded_seq_tensor.size(0), embedded_seq_tensor.size(1), -1
)

# Pack the input. packed_input.data size: torch.Size([18, 8]) (8 + 6 + 4 real steps)
packed_input = rnn_utils.pack_padded_sequence(
    embedded_seq_tensor, seq_lengths.tolist(), batch_first=True
)

# Input to LSTM. ht and ct: torch.Size([1, 3, 5]), taken at each sequence's
# last real (non-padded) step.
packed_output, (ht, ct) = lstm(packed_input)

# Unpack back to a padded batch. output: torch.Size([3, 8, 5]), zero-filled
# past each sequence's length; output_lengths equals seq_lengths.
output, output_lengths = rnn_utils.pad_packed_sequence(packed_output, batch_first=True)
39. File "/Users/mengwei/Documents/GitHub/packing-unpacking-pytorch-minimal-tutorial/pad_packed_demo.py", line 163, in <module>
40.     packed_output, (ht, ct) = lstm(packed_input)
41.   File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
42.     result = self.forward(*input, **kwargs)
43.   File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 175, in forward
44.     self.check_forward_args(input, hx, batch_sizes)
45.   File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 131, in check_forward_args
46.     expected_input_dim, input.dim()))
47. RuntimeError: input must have 2 dimensions, got 3
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy.

Top