jack06215

[pytorch] Variable sequence length RNN

Sep 26th, 2020
Python
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence


class ToyDataLoader(object):

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.index = 0
        self.dataset_size = 10

        # generate 10 random variable-length training samples,
        # each time step has 1 feature dimension
        self.X = [
            [[1], [1], [1], [1], [0], [0], [1], [1], [1]],
            [[1], [1], [1], [1]],
            [[0], [0], [1], [1]],
            [[1], [1], [1], [1], [1], [1], [1]],
            [[1], [1]],
            [[0]],
            [[0], [0], [0], [0], [0], [0], [0]],
            [[1]],
            [[0], [1]],
            [[1], [0]]
        ]

        # assign labels for the toy training set
        self.y = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    def __len__(self):
        return self.dataset_size // self.batch_size

    def __iter__(self):
        return self

    def __next__(self):
        if self.index + self.batch_size > self.dataset_size:
            self.index = 0
            raise StopIteration()
        if self.index == 0:  # shuffle the dataset at the start of each epoch
            tmp = list(zip(self.X, self.y))
            random.shuffle(tmp)
            self.X, self.y = zip(*tmp)
            self.y = torch.LongTensor(self.y)
        X = self.X[self.index: self.index + self.batch_size]
        y = self.y[self.index: self.index + self.batch_size]
        self.index += self.batch_size
        return X, y

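The hand-rolled loader above can also be expressed with the standard torch.utils.data pipeline. The sketch below is one possible equivalent; the names ToySequenceDataset and list_collate are illustrative, not part of the original script. The collate function keeps the variable-length sequences as plain Python lists, matching what the model's forward() expects.

import torch
from torch.utils.data import Dataset, DataLoader


class ToySequenceDataset(Dataset):
    """Hypothetical Dataset wrapper around the same toy sequences and int labels."""

    def __init__(self, X, y):
        self.X = X  # list of variable-length sequences (lists of [feature] steps)
        self.y = y  # list of int class labels

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


def list_collate(batch):
    # keep the variable-length sequences as plain Python lists so the model
    # can pad and pack them itself; only the labels are stacked into a tensor
    X = [seq for seq, _ in batch]
    y = torch.LongTensor([label for _, label in batch])
    return X, y


# usage sketch:
# trainloader = DataLoader(ToySequenceDataset(X_lists, label_list),
#                          batch_size=2, shuffle=True, collate_fn=list_collate)
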

class NaiveRNN(nn.Module):
    def __init__(self):
        super(NaiveRNN, self).__init__()
        self.lstm = nn.LSTM(1, 128)
        self.linear = nn.Linear(128, 2)

    def forward(self, X):
        '''
        Parameter:
            X: list containing variable-length training sequences
        '''

        # get the length of each seq in the batch
        seq_lengths = [len(x) for x in X]

        # convert each sequence to a torch.Tensor
        seq_tensor = [torch.Tensor(seq) for seq in X]

        # sort seq_lengths and seq_tensor by descending length, as required by
        # torch.nn.utils.rnn.pack_padded_sequence (note: this reorders the batch)
        pairs = sorted(zip(seq_lengths, seq_tensor),
                       key=lambda pair: pair[0], reverse=True)
        seq_lengths = torch.LongTensor([pair[0] for pair in pairs])
        seq_tensor = [pair[1] for pair in pairs]

        # padded_seq shape: (seq_len, batch_size, feature_size)
        padded_seq = pad_sequence(seq_tensor)

        # pack them up
        packed_seq = pack_padded_sequence(padded_seq, seq_lengths.numpy())

        # feed to rnn
        packed_output, (ht, ct) = self.lstm(packed_seq)

        # linear classification layer on the last hidden state
        y_pred = self.linear(ht[-1])

        return y_pred

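To make the pad/pack step in forward() concrete, here is a small standalone illustration (values chosen only for this example; it reuses the imports at the top of the script and is not part of the training loop):

# two toy sequences of lengths 3 and 2, one feature per time step
a = torch.Tensor([[1], [1], [0]])
b = torch.Tensor([[0], [1]])

padded = pad_sequence([a, b])  # shape (max_len, batch, feature) = (3, 2, 1)
packed = pack_padded_sequence(padded, torch.LongTensor([3, 2]))

print(padded.shape)        # torch.Size([3, 2, 1])
print(packed.batch_sizes)  # tensor([2, 2, 1]): sequences still active at each time step
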
def main():
    # trainloader = ToyDataLoader(batch_size=2)  # does not train at all !!!
    trainloader = ToyDataLoader(batch_size=1)  # it converges !!!

    model = NaiveRNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adadelta(model.parameters(), lr=1.0)

    for epoch in range(30):
        # switch to train mode
        model.train()

        for i, (X, labels) in enumerate(trainloader):

            # compute output
            outputs = model(X)
            loss = criterion(outputs, labels)

            # measure accuracy and record loss
            _, predicted = torch.max(outputs, 1)
            accu = (predicted == labels).sum().item() / labels.shape[0]

            # compute gradients and take an optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print('Epoch: [{}][{}/{}]\tLoss {:.4f}\tAccu {:.3f}'.format(
                epoch, i, len(trainloader), loss, accu))


if __name__ == '__main__':
    main()
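
A likely reason batch_size=2 does not train while batch_size=1 does: forward() sorts the sequences by length before packing, so the rows of y_pred come back in sorted order, while labels keeps the original order; with a batch of one there is nothing to reorder, so nothing gets mismatched. One way to avoid the mismatch is a sketch like the one below (assuming a PyTorch version where pack_padded_sequence accepts enforce_sorted=False, available since 1.1), which lets PyTorch sort internally and restore the original batch order in the hidden state. It reuses the imports at the top of the script; AlignedRNN is an illustrative name, not part of the original paste.

class AlignedRNN(nn.Module):
    # same architecture as NaiveRNN; only forward() differs
    def __init__(self):
        super(AlignedRNN, self).__init__()
        self.lstm = nn.LSTM(1, 128)
        self.linear = nn.Linear(128, 2)

    def forward(self, X):
        seq_lengths = torch.LongTensor([len(x) for x in X])
        seq_tensor = [torch.Tensor(seq) for seq in X]

        # no manual sort: enforce_sorted=False sorts internally and the LSTM
        # restores the original batch order in ht, so y_pred[i] still
        # corresponds to labels[i]
        padded_seq = pad_sequence(seq_tensor)
        packed_seq = pack_padded_sequence(padded_seq, seq_lengths,
                                          enforce_sorted=False)

        packed_output, (ht, ct) = self.lstm(packed_seq)
        return self.linear(ht[-1])

With the outputs aligned to the labels this way, the commented-out batch_size=2 loader would be expected to train as well (untested sketch). Equivalently, the original forward() could return the sort permutation so that main() can reindex labels before computing the loss.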