Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import argparse
- import torch
- import torch.nn as nn
- import numpy as np
- import os
- import pickle
- import math
- from tqdm import tqdm
- from data_loader import get_loader
- from build_vocab import Vocabulary
- from model import EncoderCNN, Decoder
- from torch.nn.utils.rnn import pack_padded_sequence
- from torchvision import transforms
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def main(args):
    """Train the CNN encoder + Transformer decoder captioning model on COCO.

    Args:
        args: parsed CLI namespace (currently unused; hyperparameters are
            hard-coded below — TODO: expose them as CLI flags).

    Side effects:
        Reads ``data/vocab.pkl`` and the COCO 2014 training images/annotations,
        and writes ``encoder-{epoch}-{step}.ckpt`` / ``decoder-{epoch}-{step}.ckpt``
        files under ``models/`` every 1000 steps.
    """
    # --- Hyperparameters -------------------------------------------------
    batch_size = 64
    embed_size = 512
    num_heads = 8
    num_layers = 6
    num_workers = 2
    num_epoch = 5
    lr = 1e-3
    load = False  # set True to resume from the hard-coded checkpoint below

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs('models/', exist_ok=True)

    # Standard ImageNet augmentation/normalization for the ResNet-style encoder.
    transform = transforms.Compose([
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])

    # NOTE(review): pickle.load is only safe on a vocab file you produced
    # yourself; never point this at untrusted data.
    with open('data/vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)

    data_loader = get_loader('data/resized2014',
                             'data/annotations/captions_train2014.json',
                             vocab, transform, batch_size,
                             shuffle=True, num_workers=num_workers)

    encoder = EncoderCNN(embed_size).to(device)
    encoder.fine_tune(False)  # freeze the CNN backbone; train the head only
    decoder = Decoder(len(vocab), embed_size, num_heads, embed_size,
                      num_layers).to(device)

    if load:
        # map_location lets a checkpoint saved on GPU load on a CPU-only box.
        encoder.load_state_dict(torch.load(
            os.path.join('models/', 'encoder-{}-{}.ckpt'.format(5, 5000)),
            map_location=device))
        decoder.load_state_dict(torch.load(
            os.path.join('models/', 'decoder-{}-{}.ckpt'.format(5, 5000)),
            map_location=device))
        print("Load Successful")

    criterion = nn.CrossEntropyLoss()
    encoder_optim = torch.optim.Adam(encoder.parameters(), lr=lr)
    decoder_optim = torch.optim.Adam(decoder.parameters(), lr=lr)

    for epoch in range(num_epoch):
        encoder.train()
        decoder.train()
        for i, (images, captions, lengths) in tqdm(enumerate(data_loader),
                                                   total=len(data_loader),
                                                   leave=False):
            images = images.to(device)
            captions = captions.to(device)

            features = encoder(images)
            # Teacher forcing: feed tokens [0..T-1], predict tokens [1..T].
            cap_input = captions[:, :-1]
            cap_target = captions[:, 1:]

            outputs = decoder(cap_input, features)
            # Decoder emits (seq, batch, vocab); flatten to (batch*seq, vocab)
            # for CrossEntropyLoss. reshape (not view) is required after permute.
            outputs = outputs.permute(1, 0, 2)
            outputs_flat = outputs.reshape(-1, len(vocab))
            loss = criterion(outputs_flat, cap_target.reshape(-1))

            # Clear gradients via the optimizers (covers exactly the params
            # they update), then backprop and step both halves of the model.
            encoder_optim.zero_grad()
            decoder_optim.zero_grad()
            loss.backward()
            encoder_optim.step()
            decoder_optim.step()

            if (i + 1) % 1000 == 0:
                torch.save(decoder.state_dict(), os.path.join(
                    'models/', 'decoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
                torch.save(encoder.state_dict(), os.path.join(
                    'models/', 'encoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
if __name__ == '__main__':
    # CLI entry point. No flags are registered yet, but running the parser
    # keeps `-h` working and rejects stray arguments.
    arg_parser = argparse.ArgumentParser()
    main(arg_parser.parse_args())
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement