# seq2seq (ROODAY, Nov 16th, 2019)
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import time
from pathlib import Path
 
class Encoder(nn.Module):
  def __init__(self, input_dim, hid_dim, n_layers, dropout):
    super().__init__()
   
    self.hid_dim = hid_dim
    self.n_layers = n_layers
   
    #time-first LSTM ([src len, batch, input_dim]) to match the decoder and
    #Seq2Seq.forward, which both index tensors as [len, batch, ...]
    self.rnn = nn.LSTM(input_dim, hid_dim, n_layers, dropout=dropout)
    self.dropout = nn.Dropout(dropout)
     
  def forward(self, src):
    #src = [src len, batch size, input_dim]
   
    #dropout is applied directly to the input features (there is no embedding layer)
    dropped = self.dropout(src)
   
    outputs, (hidden, cell) = self.rnn(dropped)
   
    #outputs = [src len, batch size, hid dim * n directions]
    #hidden = [n layers * n directions, batch size, hid dim]
    #cell = [n layers * n directions, batch size, hid dim]
   
    #outputs are always from the top hidden layer
   
    return hidden, cell
 
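#Minimal shape check for the encoder (a sketch with made-up toy sizes, not part
#of the training pipeline): two sequences of 7 frames with 20 features each.
_enc_check = Encoder(input_dim=20, hid_dim=32, n_layers=2, dropout=0.5)
_h, _c = _enc_check(torch.randn(7, 2, 20))
assert _h.shape == (2, 2, 32) and _c.shape == (2, 2, 32)  #[n layers, batch, hid dim]
 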
class Decoder(nn.Module):
  def __init__(self, output_dim, hid_dim, n_layers, dropout):
    super().__init__()
   
    self.output_dim = output_dim
    self.hid_dim = hid_dim
    self.n_layers = n_layers
   
    self.rnn = nn.LSTM(output_dim, hid_dim, n_layers, dropout=dropout)
    self.out = nn.Linear(hid_dim, output_dim)
    self.dropout = nn.Dropout(dropout)
     
  def forward(self, input, hidden, cell):
   
    #input = [batch size, output_dim]
    #hidden = [n layers * n directions, batch size, hid dim]
    #cell = [n layers * n directions, batch size, hid dim]
   
    #n directions in the decoder will always be 1, therefore:
    #hidden = [n layers, batch size, hid dim]
    #cell = [n layers, batch size, hid dim]
   
    input = input.unsqueeze(0)
   
    #input = [1, batch size, output_dim]
   
    dropped = self.dropout(input)
           
    output, (hidden, cell) = self.rnn(dropped, (hidden, cell))
   
    #sent len and n directions are always 1 in the decoder, therefore:
    #output = [1, batch size, hid dim]
    #hidden = [n layers, batch size, hid dim]
    #cell = [n layers, batch size, hid dim]
   
    prediction = self.out(output.squeeze(0))
   
    #prediction = [batch size, output dim]
   
    return prediction, hidden, cell
 
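#One decoding step, sketched with toy sizes (illustrative only): the decoder
#consumes the previous output frame plus the running LSTM state.
_dec_check = Decoder(output_dim=6, hid_dim=32, n_layers=2, dropout=0.5)
_h0 = torch.zeros(2, 2, 32)  #[n layers, batch, hid dim]
_c0 = torch.zeros(2, 2, 32)
_pred, _h1, _c1 = _dec_check(torch.randn(2, 6), _h0, _c0)
assert _pred.shape == (2, 6)  #one predicted frame per batch element
 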
class Seq2Seq(nn.Module):
  def __init__(self, encoder, decoder, device):
    super().__init__()
   
    self.encoder = encoder
    self.decoder = decoder
    self.device = device
   
    assert encoder.hid_dim == decoder.hid_dim, \
      "Hidden dimensions of encoder and decoder must be equal!"
    assert encoder.n_layers == decoder.n_layers, \
      "Encoder and decoder must have equal number of layers!"
     
  def forward(self, src, trg, teacher_forcing_ratio=0.5):
   
    #src = [src len, batch size, input_dim]
    #trg = [trg len, batch size, output_dim]
    #teacher_forcing_ratio is the probability of using teacher forcing,
    #e.g. 0.75 means ground-truth frames are fed back 75% of the time
   
    batch_size = trg.shape[1]
    max_len = trg.shape[0]
    trg_vocab_size = self.decoder.output_dim  #here: output feature dim, not a vocabulary
   
    #tensor to store decoder outputs
    outputs = torch.zeros(max_len, batch_size, trg_vocab_size).to(self.device)
   
    #last hidden state of the encoder is used as the initial hidden state of the decoder
    hidden, cell = self.encoder(src)
   
    #first input to the decoder is the <sos> frame
    input = trg[0, :]
   
    for t in range(1, max_len):
     
      #feed the previous frame and the running hidden/cell states,
      #receive a predicted frame and the updated states
      output, hidden, cell = self.decoder(input, hidden, cell)
     
      #place predictions in a tensor holding predictions for each step
      outputs[t] = output
     
      #decide whether to use teacher forcing for this step
      teacher_force = random.random() < teacher_forcing_ratio
     
      #the outputs are continuous keypoint coordinates, not class scores, so
      #without teacher forcing the prediction itself is fed back (an argmax
      #over a vocabulary only makes sense for token targets)
      input = trg[t] if teacher_force else output
   
    return outputs
 
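#End-to-end shape sketch with toy dimensions (illustrative, not the real data):
#a 5-step source of 20-dim frames in, a 4-step target of 6-dim frames out.
_m = Seq2Seq(Encoder(20, 32, 2, 0.5), Decoder(6, 32, 2, 0.5), torch.device('cpu'))
_out = _m(torch.randn(5, 2, 20), torch.randn(4, 2, 6))
assert _out.shape == (4, 2, 6)  #[trg len, batch, output_dim]; step 0 stays zero
 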
def init_weights(m):
  #uniform initialization in [-0.08, 0.08], following Sutskever et al. (2014)
  for name, param in m.named_parameters():
    nn.init.uniform_(param.data, -0.08, 0.08)
 
def train(model, iterator, optimizer, criterion, clip):
  model.train()
 
  epoch_loss = 0
 
  for i, batch in enumerate(iterator):
    src = batch['src'].to(model.device)
    trg = batch['trg'].to(model.device)
   
    optimizer.zero_grad()
   
    output = model(src, trg)
   
    #trg = [trg len, batch size, output_dim]
    #output = [trg len, batch size, output_dim]
   
    #drop the <sos> step and flatten time and batch for the loss
    output = output[1:].reshape(-1, output.shape[-1])
    trg = trg[1:].reshape(-1, trg.shape[-1])
   
    #trg = [(trg len - 1) * batch size, output_dim]
    #output = [(trg len - 1) * batch size, output_dim]
   
    loss = criterion(output, trg)
   
    loss.backward()
   
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
   
    optimizer.step()
   
    epoch_loss += loss.item()
     
  return epoch_loss / len(iterator)
 
def evaluate(model, iterator, criterion):
  model.eval()
 
  epoch_loss = 0
 
  with torch.no_grad():
    for i, batch in enumerate(iterator):
      src = batch['src'].to(model.device)
      trg = batch['trg'].to(model.device)
 
      output = model(src, trg, 0) #turn off teacher forcing
 
      #drop the <sos> step and flatten, as in train()
      output = output[1:].reshape(-1, output.shape[-1])
      trg = trg[1:].reshape(-1, trg.shape[-1])
 
      loss = criterion(output, trg)
     
      epoch_loss += loss.item()
     
  return epoch_loss / len(iterator)
 
def epoch_time(start_time, end_time):
  elapsed_time = end_time - start_time
  elapsed_mins = int(elapsed_time / 60)
  elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
  return elapsed_mins, elapsed_secs
 
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
data_dir = Path(Path.cwd(), 'data/', 'test')
 
#left-pad every MFCC sequence to the longest one in the set
mfccs = [np.load(path) for path in sorted(data_dir.rglob('*.mfcc.npy'))]
max_mfcc_len = max(mfcc.shape[0] for mfcc in mfccs)
mfccs = [np.pad(mfcc, [(max_mfcc_len - len(mfcc), 0), (0, 0)]) for mfcc in mfccs]
 
#left-pad the keypoint sequences the same way, then flatten each (17, 3) frame
#to a 51-dim vector so it can feed the decoder LSTM
keypoints = [np.load(path) for path in sorted(data_dir.rglob('*.keypoints.npy'))]
max_kp_len = max(kp.shape[0] for kp in keypoints)
keypoints = [np.pad(kp, [(max_kp_len - len(kp), 0), (0, 0), (0, 0)]) for kp in keypoints]
keypoints = [kp.reshape(len(kp), -1) for kp in keypoints]
 
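#Quick illustration of the left-padding used above (a toy array, not the data):
#only the time axis is padded, the feature axis is left alone.
_toy = np.pad(np.ones((2, 3)), [(2, 0), (0, 0)])  #two zero rows prepended
assert _toy.shape == (4, 3) and not _toy[0].any()
 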
#sentinel frames marking sequence start/end; finite values are used because
#np.inf would propagate NaNs through the LSTM and the loss
input_sos = np.full((20,), -1.0)
input_eos = np.full((1, 20), 1.0)
output_sos = np.full((1, 17 * 3), -1.0)
output_eos = np.full((1, 17 * 3), 1.0)
 
#wrap each sequence in <sos>/<eos>, stack into a batch, and permute to the
#time-first layout [seq len, batch, features] the model expects
batch_mfccs = torch.tensor(np.stack([np.append(np.insert(mfcc, 0, input_sos, axis=0), input_eos, axis=0) for mfcc in mfccs])).float().permute(1, 0, 2)
batch_kps   = torch.tensor(np.stack([np.append(np.insert(kp, 0, output_sos, axis=0), output_eos, axis=0) for kp in keypoints])).float().permute(1, 0, 2)
it = [{'src': batch_mfccs, 'trg': batch_kps}]
 
for x in it:
  print('src shape: {}, trg shape: {}'.format(x['src'].shape, x['trg'].shape))
 
INPUT_DIM = 20       #MFCC coefficients per frame
OUTPUT_DIM = 17 * 3  #17 keypoints x 3 coordinates, flattened per frame
HID_DIM = 512
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
 
enc = Encoder(INPUT_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)
model = Seq2Seq(enc, dec, device).to(device)
       
model.apply(init_weights)
 
optimizer = optim.Adam(model.parameters())
#keypoint coordinates are continuous, so this is a regression problem;
#CrossEntropyLoss expects integer class targets and would fail here
criterion = nn.MSELoss()
 
N_EPOCHS = 10
CLIP = 1
 
best_valid_loss = float('inf')
 
for epoch in range(N_EPOCHS):
  start_time = time.time()
 
  #the single available batch doubles as train and validation data in this paste
  train_loss = train(model, it, optimizer, criterion, CLIP)
  valid_loss = evaluate(model, it, criterion)
 
  end_time = time.time()
 
  epoch_mins, epoch_secs = epoch_time(start_time, end_time)
 
  if valid_loss < best_valid_loss:
    best_valid_loss = valid_loss
    torch.save(model.state_dict(), 'tut1-model.pt')
 
  print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
  print(f'\tTrain Loss: {train_loss:.3f}')
  print(f'\t Val. Loss: {valid_loss:.3f}')
 
model.load_state_dict(torch.load('tut1-model.pt'))
#no separate test split exists in this paste, so the same batch is reused for testing
test_loss = evaluate(model, it, criterion)
print(f'| Test Loss: {test_loss:.3f} |')
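 
#A minimal greedy-inference sketch (an assumption about intended use, not part
#of the original paste): encode the MFCCs, then roll the decoder forward from
#the <sos> frame, feeding each prediction back in as the next input.
def generate(model, src, sos_frame, max_len=50):
  model.eval()
  with torch.no_grad():
    hidden, cell = model.encoder(src.to(model.device))
    frame = sos_frame.to(model.device)  #[batch, output_dim]
    frames = []
    for _ in range(max_len):
      frame, hidden, cell = model.decoder(frame, hidden, cell)
      frames.append(frame)
    return torch.stack(frames)          #[max_len, batch, output_dim]
 
predicted_kps = generate(model, batch_mfccs, batch_kps[0])
print('generated keypoints shape: {}'.format(predicted_kps.shape))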