Advertisement
Guest User

Untitled

a guest
Oct 15th, 2020
32
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.67 KB | None | 0 0
  1. import torch
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  4. import argparse
  5. import pickle
  6. import os
  7. from torchvision import transforms
  8. from build_vocab import Vocabulary
  9. from data_loader import get_loader
  10. from model import EncoderCNN, Decoder
  11. from PIL import Image
  12.  
  13. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  14. def token_sentence(decoder_out, itos):
  15.     tokens = decoder_out
  16.     tokens = tokens.transpose(1, 0)
  17.     tokens = tokens.cpu().numpy()
  18.     results = []
  19.     for instance in tokens:
  20.         result = ' '.join([itos[x] for x in instance])
  21.         results.append(''.join(result.partition('<eos>')[0]))
  22.     return results
  23.  
  24. def load_image(image_path, transform=None):
  25.     image = Image.open(image_path).convert('RGB')
  26.     image = image.resize([224, 224], Image.LANCZOS)
  27.    
  28.     if transform is not None:
  29.         image = transform(image).unsqueeze(0)
  30.    
  31.     return image
  32.  
  33. def main(args):
  34.     batch_size = 64
  35.     embed_size = 512
  36.     num_heads = 8
  37.     num_layers = 6
  38.     num_workers = 2
  39.    
  40.     transform = transforms.Compose([
  41.         transforms.ToTensor(),
  42.         transforms.Normalize((0.485, 0.456, 0.406),
  43.                              (0.229, 0.224, 0.225))])
  44.    
  45.     with open('data/vocab.pkl', 'rb') as f:
  46.         vocab = pickle.load(f)
  47.  
  48.     data_loader = get_loader('data/resized2014', 'data/annotations/captions_train2014.json', vocab,
  49.                              transform, batch_size,
  50.                              shuffle=True, num_workers=num_workers)
  51.  
  52.     encoder = EncoderCNN(embed_size).to(device)
  53.     encoder.fine_tune(False)
  54.     decoder = Decoder(len(vocab), embed_size, num_heads, embed_size, num_layers).to(device)
  55.  
  56.     encoder.load_state_dict(torch.load(os.path.join('models/', 'encoder-{}-{}.ckpt'.format(1, 4000))))
  57.     decoder.load_state_dict(torch.load(os.path.join('models/', 'decoder-{}-{}.ckpt'.format(1, 4000))))
  58.     encoder.eval()
  59.     decoder.eval()
  60.    
  61.     itos = vocab.idx2word
  62.     pred_len = 100
  63.     result_collection = []
  64.  
  65.     with torch.no_grad():
  66.         for batch_index, (inputs, captions, caplens) in enumerate(data_loader):
  67.             inputs, captions = inputs.cuda(), captions.cuda()
  68.             enc_out = encoder(inputs)
  69.             captions_input = captions[:, :-1]
  70.             captions_target = captions[:, 1:]
  71.             output = decoder.pred(enc_out, pred_len)
  72.             result_caption = token_sentence(output, itos)
  73.             result_collection.extend(result_caption)
  74.        
  75.            
  76.     print("Prediction-greedy:", result_collection[1])
  77.    
  78. if __name__ == '__main__':
  79.     parser = argparse.ArgumentParser()
  80.     args = parser.parse_args()
  81.     main(args)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement