import torch
from torch.nn import Embedding, LSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

# Define the data to be batched: three sequences of different lengths,
# each timestep holding two token indices (shape [seq_len, 2])
a = torch.tensor([[6, 1], [9, 2], [8, 4], [4, 14], [1, 13], [11, 3], [12, 1],
                  [10, 10]])
b = torch.tensor([[7, 2], [3, 13], [2, 3], [5, 6], [13, 8], [7, 10]])
c = torch.tensor([[12, 2], [5, 3], [8, 7], [14, 10]])

# The tensors are used directly; wrapping them in torch.autograd.Variable
# is deprecated and unnecessary
vectorized_seqs = [a, b, c]

# Define the LSTM model with an embedding layer
embed = Embedding(15, 4)  # vocab size 15, embedding_dim = 4
lstm = LSTM(input_size=4, hidden_size=5, batch_first=True)  # input_dim = 4, hidden_dim = 5

# Get sequence lengths: (8, 6, 4), already in the descending order
# that pack_padded_sequence requires
seq_lengths = torch.tensor([len(seq) for seq in vectorized_seqs])

# Pad the sequences to a common length
seq_tensor = pad_sequence(vectorized_seqs, batch_first=True)  # shape [3, 8, 2]

# Feed the padded indices through the embedding layer; both indices at each
# timestep are embedded, so the shape is now torch.Size([3, 8, 2, 4])
embedded_seq_tensor = embed(seq_tensor)

# Pack the input; the packed data has shape [18, 2, 4] (18 = 8 + 6 + 4)
packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths, batch_first=True)

# Feed the packed batch to the LSTM. The LSTM expects packed data of shape
# [total_timesteps, input_size] = [18, 4] but receives [18, 2, 4],
# so this raises the RuntimeError shown below
packed_output, (ht, ct) = lstm(packed_input)
Traceback (most recent call last):
  File "/Users/mengwei/Documents/GitHub/packing-unpacking-pytorch-minimal-tutorial/pad_packed_demo.py", line 163, in <module>
    packed_output, (ht, ct) = lstm(packed_input)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 175, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 131, in check_forward_args
    expected_input_dim, input.dim()))
RuntimeError: input must have 2 dimensions, got 3
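
# One way to resolve the error, sketched below under the assumption that both
# token indices at each timestep are meant to contribute features: embed as
# before, then merge the two 4-dim embeddings at each timestep into a single
# 8-dim vector and give the LSTM input_size=8. Variable names are illustrative,
# not taken from the original script.

import torch
from torch.nn import Embedding, LSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

a = torch.tensor([[6, 1], [9, 2], [8, 4], [4, 14], [1, 13], [11, 3], [12, 1], [10, 10]])
b = torch.tensor([[7, 2], [3, 13], [2, 3], [5, 6], [13, 8], [7, 10]])
c = torch.tensor([[12, 2], [5, 3], [8, 7], [14, 10]])
vectorized_seqs = [a, b, c]

embed = Embedding(15, 4)
# input_size = 2 * 4 = 8: the two embeddings per timestep are concatenated below
lstm = LSTM(input_size=8, hidden_size=5, batch_first=True)

seq_lengths = torch.tensor([len(seq) for seq in vectorized_seqs])
seq_tensor = pad_sequence(vectorized_seqs, batch_first=True)      # [3, 8, 2]

embedded = embed(seq_tensor)                                      # [3, 8, 2, 4]
embedded = embedded.view(embedded.size(0), embedded.size(1), -1)  # [3, 8, 8]

packed_input = pack_padded_sequence(embedded, seq_lengths, batch_first=True)
packed_output, (ht, ct) = lstm(packed_input)  # packed data is [18, 8] -- no error

# Unpack back to a padded tensor (this is what pad_packed_sequence is for)
output, output_lengths = pad_packed_sequence(packed_output, batch_first=True)
print(output.shape)       # torch.Size([3, 8, 5])
print(output_lengths)     # tensor([8, 6, 4])

# Alternatively, if each sequence was meant to be a plain sequence of token
# indices, define a, b, c as 1-D tensors and keep input_size=4.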