Advertisement
frolkin28

prod2vec

Dec 9th, 2023 (edited)
475
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.35 KB | None | 0 0
import torch
import torchvision.models as models
from torch import linalg
from torch import nn
from torch.nn import functional as F
from transformers import AutoConfig
from transformers import AutoModel
  5.  
  6.  
  7. class ArcFace(nn.Module):
  8.     def __init__(self, cin, cout, s=8, m=0.5):
  9.         super().__init__()
  10.         self.s = s
  11.         self.sin_m = torch.sin(torch.tensor(m))
  12.         self.cos_m = torch.cos(torch.tensor(m))
  13.         self.cout = cout
  14.         self.fc = nn.Linear(cin, cout, bias=False)
  15.  
  16.     def forward(self, x, label=None):
  17.         w_L2 = linalg.norm(self.fc.weight.detach(), dim=1, keepdim=True).T
  18.         x_L2 = linalg.norm(x, dim=1, keepdim=True)
  19.         cos = self.fc(x) / (x_L2 * w_L2)
  20.  
  21.         if label is not None:
  22.             sin_m, cos_m = self.sin_m, self.cos_m
  23.             one_hot = F.one_hot(label, num_classes=self.cout)
  24.             sin = (1 - cos ** 2) ** 0.5
  25.             angle_sum = cos * cos_m - sin * sin_m
  26.             cos = angle_sum * one_hot + cos * (1 - one_hot)
  27.             cos = cos * self.s
  28.  
  29.         return cos
  30.  
  31. config = AutoConfig.from_pretrained("cointegrated/rubert-tiny2")
  32. bert_out_features = config.hidden_size
  33. NUM_CATEGORIES = 1800
  34. NUM_FEATURES = 512
  35.  
  36. class ProductVectorizer(nn.Module):
  37.     def __init__(self, num_classes):
  38.         super().__init__()
  39.         self.cnn = models.resnet34(pretrained=True)
  40.         self.bert = AutoModel.from_pretrained(BERT_MODEL)
  41.  
  42.         self.image_classifier = nn.Linear(in_features=bert_out_features, out_features=NUM_CATEGORIES)
  43.         self.name_classifier = nn.Linear(in_features=bert_out_features, out_features=NUM_CATEGORIES)
  44.         self.attrs_classifier = nn.Linear(in_features=bert_out_features, out_features=NUM_CATEGORIES)
  45.  
  46.         # Concatenate Embeddings and Batch Normalization
  47.         self.concat_bn = nn.BatchNorm1d(self.cnn.fc.in_features + self.bert.bert_out_features * 2)
  48.         self.final_embedding = nn.Linear(self.cnn.fc.in_features + self.bert.config.hidden_size * 2, NUM_FEATURES)
  49.  
  50.         self.arcface_layer = ArcFace(NUM_FEATURES, NUM_CATEGORIES)
  51.         # self.classifier = nn.Linear(512, NUM_CATEGORIES)
  52.  
  53.         # twick modules params
  54.         self.cnn.fc = nn.Linear(self.cnn.fc.in_features, bert_out_features)
  55.  
  56.         # additional params
  57.         self.image_ce_input = None
  58.         self.name_ce_input = None
  59.         self.attrs_ce_input = None
  60.  
  61.     def _forward_bert(self, bert_module, x):
  62.       with torch.inference_mode():
  63.           model_output = bert_module(input_ids=x[0], token_type_ids=x[1], attention_mask=x[2])
  64.       embeddings = model_output.last_hidden_state[:, 0, :]
  65.       embeddings = torch.nn.functional.normalize(embeddings)
  66.       return embeddings[0]
  67.  
  68.     def forward(self, x, y):
  69.       image_embeddings = self.cnn(x)
  70.       name_embeddings = self._forward_bert(self.bert, x)
  71.       attrs_embeddings = self._forward_bert(self.bert, x)
  72.  
  73.       self.image_ce_input = self.image_classifier(image_embeddings)
  74.       self.name_ce_input = self.name_classifier(name_embeddings)
  75.       self.attrs_ce_input = self.attrs_classifier(attrs_embeddings)
  76.  
  77.       concatenated_embeddings = torch.cat((image_embeddings, name_embeddings, attrs_embeddings), dim=1)
  78.       concatenated_embeddings = self.concat_bn(concatenated_embeddings)
  79.       final_embeddings = self.final_embedding(concatenated_embeddings)
  80.  
  81.       if self.training and y:
  82.         return self.arcface_layer(final_embeddings, y)
  83.  
  84.       return final_embeddings
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement