Advertisement
lamiastella

img_to_vec.py

Oct 18th, 2018
241
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.16 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torchvision.models as models
  4. import torchvision.transforms as transforms
  5.  
  6.  
  7. class Img2Vec():
  8.  
  9.     def __init__(self, cuda=False, model='resnet-50', layer='default', layer_output_size=512):
  10.         """ Img2Vec
  11.        :param cuda: If set to True, will run forward pass on GPU
  12.        :param model: String name of requested model
  13.        :param layer: String or Int depending on model.  See more docs: https://github.com/christiansafka/img2vec.git
  14.        :param layer_output_size: Int depicting the output size of the requested layer
  15.        """
  16.         self.device = torch.device("cuda" if cuda else "cpu")
  17.         self.layer_output_size = layer_output_size
  18.         self.model, self.extraction_layer = self._get_model_and_layer(model, layer)
  19.  
  20.         self.model = self.model.to(self.device)
  21.  
  22.         self.model.eval()
  23.  
  24.         self.scaler = transforms.Scale((224, 224))
  25.         self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
  26.                                               std=[0.229, 0.224, 0.225])
  27.         self.to_tensor = transforms.ToTensor()
  28.  
  29.     def get_vec(self, img, tensor=False):
  30.         """ Get vector embedding from PIL image
  31.        :param img: PIL Image
  32.        :param tensor: If True, get_vec will return a FloatTensor instead of Numpy array
  33.        :returns: Numpy ndarray
  34.        """
  35.         image = self.normalize(self.to_tensor(self.scaler(img))).unsqueeze(0).to(self.device)
  36.  
  37.         my_embedding = torch.zeros(1, self.layer_output_size, 1, 1)
  38.  
  39.         def copy_data(m, i, o):
  40.             my_embedding.copy_(o.data)
  41.  
  42.         h = self.extraction_layer.register_forward_hook(copy_data)
  43.         h_x = self.model(image)
  44.         h.remove()
  45.  
  46.         if tensor:
  47.             return my_embedding
  48.         else:
  49.             return my_embedding.numpy()[0, :, 0, 0]
  50.  
  51.     def _get_model_and_layer(self, model_name, layer):
  52.         """ Internal method for getting layer from model
  53.        :param model_name: model name such as 'resnet-18'
  54.        :param layer: layer as a string for resnet-18 or int for alexnet
  55.        :returns: pytorch model, selected layer
  56.        """
  57.         if model_name == 'resnet-18':
  58.             model = models.resnet18(pretrained=True)
  59.             if layer == 'default':
  60.                 layer = model._modules.get('avgpool')
  61.                 self.layer_output_size = 512
  62.             else:
  63.                 layer = model._modules.get(layer)
  64.  
  65.             return model, layer
  66.         elif model_name == 'resnet-50':
  67.             model = models.resnet50(pretrained=True)
  68.             if layer == 'default':
  69.                 layer = model._modules.get('avgpool')
  70.                 self.layer_output_size = 512
  71.             else:
  72.                 layer = model._modules.get(layer)
  73.  
  74.         elif model_name == 'alexnet':
  75.             model = models.alexnet(pretrained=True)
  76.             if layer == 'default':
  77.                 layer = model.classifier[-2]
  78.                 self.layer_output_size = 4096
  79.             else:
  80.                 layer = model.classifier[-layer]
  81.  
  82.             return model, layer
  83.  
  84.         else:
  85.             raise KeyError('Model %s was not found' % model_name)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement