import argparse
import os
import pickle

import torch
import torch.nn as nn
from torchvision import transforms
from tqdm import tqdm

from data_loader import get_loader
from build_vocab import Vocabulary  # imported so pickle.load can rebuild the vocab object
from model import EncoderCNN, Decoder
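# data_loader, build_vocab, and model are companion modules from this project
# (not included in the paste); the script will not run without them.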

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def main(args):
    # All hyperparameters are hard-coded here rather than read from args.
    batch_size = 64
    embed_size = 512
    num_heads = 8
    num_layers = 6
    num_workers = 2
    num_epochs = 5
    lr = 1e-3
    load = False  # set True to resume from the checkpoints below

    os.makedirs('models/', exist_ok=True)

    transform = transforms.Compose([
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # ImageNet channel means/stds, matching the pretrained encoder backbone.
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])
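    # RandomCrop(224) with no prior Resize assumes the images in data/resized2014
    # were already resized to at least 224 px per side by a preprocessing step,
    # as the folder name suggests.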

    with open('data/vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)

    data_loader = get_loader('data/resized2014',
                             'data/annotations/captions_train2014.json',
                             vocab, transform, batch_size,
                             shuffle=True, num_workers=num_workers)
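    # Assumed get_loader contract (following the common PyTorch image-captioning
    # tutorial layout): each batch is (images, padded caption tensor, list of
    # true caption lengths), which the training loop unpacks below.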

    encoder = EncoderCNN(embed_size).to(device)
    encoder.fine_tune(False)  # disable fine-tuning of the pretrained CNN
    decoder = Decoder(len(vocab), embed_size, num_heads,
                      embed_size, num_layers).to(device)
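    # Assumed Decoder interface (its definition is not in this paste): a
    # Transformer-style decoder whose forward(captions, memory) returns logits
    # shaped (seq_len, batch, vocab_size) and builds its own causal mask;
    # the permute(1, 0, 2) in the training loop depends on that layout.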

    if load:
        # map_location lets the checkpoints load on CPU-only machines too.
        encoder.load_state_dict(torch.load(
            os.path.join('models/', 'encoder-{}-{}.ckpt'.format(5, 5000)),
            map_location=device))
        decoder.load_state_dict(torch.load(
            os.path.join('models/', 'decoder-{}-{}.ckpt'.format(5, 5000)),
            map_location=device))
        print('Load successful')

    criterion = nn.CrossEntropyLoss()
    encoder_optim = torch.optim.Adam(encoder.parameters(), lr=lr)
    decoder_optim = torch.optim.Adam(decoder.parameters(), lr=lr)
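    # Padding tokens currently contribute to the loss. If the Vocabulary class
    # exposes its <pad> index (an assumption about this project's build_vocab),
    # masking it usually trains better, e.g.:
    #     criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx['<pad>'])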

    for epoch in range(num_epochs):
        encoder.train()
        decoder.train()
        for i, (images, captions, lengths) in tqdm(enumerate(data_loader),
                                                   total=len(data_loader),
                                                   leave=False):
            # lengths is unused: the loss is computed over the full padded
            # sequence rather than packed to the true caption lengths.
            images = images.to(device)
            captions = captions.to(device)

            features = encoder(images)
            # Teacher forcing: feed tokens 0..T-1, predict tokens 1..T.
            cap_input = captions[:, :-1]
            cap_target = captions[:, 1:]
            outputs = decoder(cap_input, features)
            # (seq_len, batch, vocab) -> (batch, seq_len, vocab)
            outputs = outputs.permute(1, 0, 2)
            outputs_flat = outputs.reshape(-1, len(vocab))
            loss = criterion(outputs_flat, cap_target.reshape(-1))

            encoder_optim.zero_grad()
            decoder_optim.zero_grad()
            loss.backward()
            encoder_optim.step()
            decoder_optim.step()

            # Checkpoint every 1000 steps.
            if (i + 1) % 1000 == 0:
                torch.save(decoder.state_dict(), os.path.join(
                    'models/', 'decoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
                torch.save(encoder.state_dict(), os.path.join(
                    'models/', 'encoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
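    # If training is unstable at lr = 1e-3 (common for Transformer decoders
    # trained with Adam), gradient clipping between loss.backward() and the
    # optimizer steps is a standard remedy, e.g.:
    #     torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.0)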

if __name__ == '__main__':
    # The parser defines no options; args exists only to satisfy main's
    # signature, and all settings live at the top of main.
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    main(args)