Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import torch
from torch import linalg, nn
from torch.nn import functional as F

import torchvision.models as models
from transformers import AutoConfig, AutoModel
class ArcFace(nn.Module):
    """Additive angular margin (ArcFace) classification head.

    Computes scaled cosine similarity between L2-normalized input embeddings
    and L2-normalized class weight vectors. During training (when `label` is
    given) an angular margin `m` is added to the target-class angle:
    cos(theta + m) = cos(theta)*cos(m) - sin(theta)*sin(m).

    Args:
        cin: embedding dimensionality.
        cout: number of classes.
        s: logit scale factor applied to the cosine output.
        m: additive angular margin (radians) applied to the target class.
    """

    def __init__(self, cin, cout, s=8, m=0.5):
        super().__init__()
        self.s = s
        # Precompute sin(m)/cos(m) once; m is fixed for the layer's lifetime.
        self.sin_m = torch.sin(torch.tensor(m))
        self.cos_m = torch.cos(torch.tensor(m))
        self.cout = cout
        # Bias-free linear layer: its weight rows act as class centers.
        self.fc = nn.Linear(cin, cout, bias=False)

    def forward(self, x, label=None):
        """Return (batch, cout) logits; margin is applied only when `label` is given.

        Raises: whatever `F.one_hot` raises for out-of-range labels.
        """
        # cos(theta) = (x . w) / (|x| |w|); weight norms are detached so the
        # normalization itself does not receive gradients.
        w_L2 = linalg.norm(self.fc.weight.detach(), dim=1, keepdim=True).T
        x_L2 = linalg.norm(x, dim=1, keepdim=True)
        cos = self.fc(x) / (x_L2 * w_L2)
        if label is not None:
            one_hot = F.one_hot(label, num_classes=self.cout)
            # BUG FIX: clamp before the sqrt — floating-point error can push
            # |cos| slightly above 1, making (1 - cos**2) negative and the
            # sqrt NaN, which poisons the loss.
            sin = (1 - cos.clamp(-1.0, 1.0) ** 2) ** 0.5
            # Angle-addition formula: cos(theta + m).
            angle_sum = cos * self.cos_m - sin * self.sin_m
            # Apply the margin only to each sample's target class.
            cos = angle_sum * one_hot + cos * (1 - one_hot)
        # Scale in both branches so train/eval logits share the same range.
        return cos * self.s
# Fetch only the rubert-tiny2 *config* (no weights) to learn the text
# encoder's hidden size, so the fusion layers below can match its width.
config = AutoConfig.from_pretrained("cointegrated/rubert-tiny2")
bert_out_features = config.hidden_size  # 312 for rubert-tiny2 — TODO confirm
NUM_CATEGORIES = 1800  # number of product categories (classifier output width)
NUM_FEATURES = 512  # dimensionality of the final product embedding
class ProductVectorizer(nn.Module):
    """Multi-modal product encoder.

    Fuses a ResNet-34 image embedding with two frozen-BERT text embeddings
    (product name and attributes): the three embeddings are concatenated,
    batch-normalized and projected to a NUM_FEATURES-dim product embedding.
    During training the embedding goes through an ArcFace head; per-modality
    auxiliary logits are stashed on the instance for extra CE losses.
    """

    def __init__(self, num_classes):
        # NOTE(review): `num_classes` is unused — classifier widths come from
        # the module-level NUM_CATEGORIES constant. Kept for caller compatibility.
        super().__init__()
        self.cnn = models.resnet34(pretrained=True)
        # NOTE(review): BERT_MODEL is not defined in this chunk — presumably
        # "cointegrated/rubert-tiny2" to match the AutoConfig above; verify.
        self.bert = AutoModel.from_pretrained(BERT_MODEL)
        # Auxiliary per-modality classifiers; all modalities are
        # bert_out_features wide (the CNN head is resized below).
        self.image_classifier = nn.Linear(in_features=bert_out_features, out_features=NUM_CATEGORIES)
        self.name_classifier = nn.Linear(in_features=bert_out_features, out_features=NUM_CATEGORIES)
        self.attrs_classifier = nn.Linear(in_features=bert_out_features, out_features=NUM_CATEGORIES)
        # Concatenated width = image + name + attrs, each bert_out_features wide.
        # BUG FIX: was `self.cnn.fc.in_features + self.bert.bert_out_features * 2`,
        # which (a) raised AttributeError (`bert_out_features` is not an attribute
        # of the BERT module) and (b) sized BatchNorm/Linear for the *original*
        # CNN fc width even though the CNN head is replaced below to emit
        # bert_out_features — forward() would crash on a shape mismatch.
        concat_features = bert_out_features * 3
        self.concat_bn = nn.BatchNorm1d(concat_features)
        self.final_embedding = nn.Linear(concat_features, NUM_FEATURES)
        self.arcface_layer = ArcFace(NUM_FEATURES, NUM_CATEGORIES)
        # Resize the ResNet head so image embeddings match the BERT width.
        self.cnn.fc = nn.Linear(self.cnn.fc.in_features, bert_out_features)
        # Auxiliary logits captured each forward pass (read by the training loop).
        self.image_ce_input = None
        self.name_ce_input = None
        self.attrs_ce_input = None

    def _forward_bert(self, bert_module, x):
        """Encode tokenized text with frozen BERT.

        `x` is (input_ids, token_type_ids, attention_mask); returns
        L2-normalized CLS embeddings of shape (batch, hidden_size).
        """
        # BUG FIX: was `torch.inference_mode()`. Inference tensors cannot be
        # saved for backward, so the downstream BatchNorm/Linear would fail
        # during training; `no_grad` freezes BERT without that restriction.
        with torch.no_grad():
            model_output = bert_module(input_ids=x[0], token_type_ids=x[1], attention_mask=x[2])
        embeddings = model_output.last_hidden_state[:, 0, :]  # CLS token
        embeddings = torch.nn.functional.normalize(embeddings)
        # BUG FIX: was `return embeddings[0]`, which dropped the batch
        # dimension (and every sample but the first); forward() concatenates
        # on dim=1 and needs the full (batch, hidden) tensor.
        return embeddings

    def forward(self, x, y):
        # NOTE(review): the same `x` is fed to the CNN and both BERT calls —
        # presumably x bundles (image, name_tokens, attrs_tokens) and the
        # per-modality indexing is incomplete here; confirm against the caller.
        image_embeddings = self.cnn(x)
        name_embeddings = self._forward_bert(self.bert, x)
        attrs_embeddings = self._forward_bert(self.bert, x)
        # Stash auxiliary logits for the training loop's per-modality CE losses.
        self.image_ce_input = self.image_classifier(image_embeddings)
        self.name_ce_input = self.name_classifier(name_embeddings)
        self.attrs_ce_input = self.attrs_classifier(attrs_embeddings)
        concatenated_embeddings = torch.cat((image_embeddings, name_embeddings, attrs_embeddings), dim=1)
        concatenated_embeddings = self.concat_bn(concatenated_embeddings)
        final_embeddings = self.final_embedding(concatenated_embeddings)
        # BUG FIX: was `if self.training and y:` — truth-testing a
        # multi-element tensor raises "Boolean value of Tensor is ambiguous".
        if self.training and y is not None:
            return self.arcface_layer(final_embeddings, y)
        return final_embeddings
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement