Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import matplotlib.pyplot as plt
- import numpy as np
- import argparse
- import pickle
- import os
- from torchvision import transforms
- from build_vocab import Vocabulary
- from data_loader import get_loader
- from model import EncoderCNN, Decoder
- from PIL import Image
# Select the compute device once for the whole script: the default CUDA
# device when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def token_sentence(decoder_out, itos, eos_token='<eos>'):
    """Convert a batch of decoded token ids into space-separated sentences.

    Each sequence is truncated at its first ``eos_token``. The end token and
    anything after it are dropped, and no trailing whitespace is left behind.

    Args:
        decoder_out: Tensor of token ids. It is transposed with
            ``transpose(1, 0)`` so that each row becomes one sequence; this
            presumes a (seq_len, batch) layout -- TODO confirm against
            ``Decoder.pred``.
        itos: Mapping (dict or list) from token id to word string.
        eos_token: Word that marks the end of a sentence.

    Returns:
        A list of strings, one per row of the transposed ``decoder_out``.
    """
    # .tolist() yields plain Python ints, so dict/list lookups in itos behave
    # exactly as for int keys.
    rows = decoder_out.transpose(1, 0).cpu().tolist()
    results = []
    for row in rows:
        words = []
        for idx in row:
            word = itos[idx]
            # Match the end token as a whole token rather than as a substring
            # of the joined sentence.
            if word == eos_token:
                break
            words.append(word)
        results.append(' '.join(words))
    return results
def load_image(image_path, transform=None, size=(224, 224)):
    """Open an image file, convert it to RGB and resize it with Lanczos resampling.

    Args:
        image_path: Path to the image file.
        transform: Optional callable applied to the resized PIL image. Its
            result gets a leading dimension via ``unsqueeze(0)``, so it is
            expected to return a tensor (e.g. a ``transforms.Compose`` ending
            in ``ToTensor``/``Normalize``).
        size: Target ``(width, height)`` in pixels; defaults to 224x224 as
            before.

    Returns:
        The resized PIL image when ``transform`` is None; otherwise the
        transformed result with a size-1 dimension prepended.
    """
    # Open inside a context manager so the underlying file handle is closed.
    # convert() returns a new, fully loaded image that stays valid after the
    # source image is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image = image.resize(tuple(size), Image.LANCZOS)
    if transform is not None:
        image = transform(image).unsqueeze(0)
    return image
def main(args):
    """Greedy-decode captions over the train2014 loader and print one sample.

    Loads the pickled vocabulary, builds the COCO data loader, restores the
    encoder/decoder checkpoints ``*-1-4000.ckpt`` from ``models/``, and runs
    ``decoder.pred`` on every batch. All decoded captions are collected, and
    the second one is printed.

    Args:
        args: Parsed command-line namespace. Currently unused; the parser in
            ``__main__`` declares no options.
    """
    batch_size = 64
    embed_size = 512
    num_heads = 8
    num_layers = 6
    num_workers = 2
    pred_len = 100  # passed to decoder.pred as the decoding length limit

    # ImageNet mean/std normalization; presumably matches the pretrained
    # backbone inside EncoderCNN -- TODO confirm.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # NOTE(review): pickle.load can execute arbitrary code; only load a
    # vocab file you produced yourself.
    with open('data/vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)

    data_loader = get_loader('data/resized2014', 'data/annotations/captions_train2014.json', vocab,
                             transform, batch_size,
                             shuffle=True, num_workers=num_workers)

    encoder = EncoderCNN(embed_size).to(device)
    encoder.fine_tune(False)
    decoder = Decoder(len(vocab), embed_size, num_heads, embed_size, num_layers).to(device)

    encoder_path = os.path.join('models/', 'encoder-{}-{}.ckpt'.format(1, 4000))
    decoder_path = os.path.join('models/', 'decoder-{}-{}.ckpt'.format(1, 4000))
    # map_location lets checkpoints saved on a GPU be restored on a CPU-only
    # machine (and vice versa).
    encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    decoder.load_state_dict(torch.load(decoder_path, map_location=device))
    encoder.eval()
    decoder.eval()

    itos = vocab.idx2word
    result_collection = []
    with torch.no_grad():
        # Ground-truth captions and lengths are not needed for free-running
        # greedy decoding, so they are neither moved to the device nor used.
        for inputs, _captions, _caplens in data_loader:
            # Use the module-level device instead of a hard-coded .cuda() so
            # the script also runs without a GPU.
            inputs = inputs.to(device)
            enc_out = encoder(inputs)
            output = decoder.pred(enc_out, pred_len)
            result_collection.extend(token_sentence(output, itos))

    # The script reports the second decoded caption; guard against loaders
    # that produce fewer than two predictions instead of raising IndexError.
    if len(result_collection) > 1:
        print("Prediction-greedy:", result_collection[1])
    else:
        print("Prediction-greedy: fewer than two predictions were produced")
if __name__ == '__main__':
    # The parser currently declares no options; parsing still provides --help
    # and rejects unexpected command-line arguments before inference starts.
    cli_args = argparse.ArgumentParser().parse_args()
    main(cli_args)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement