Advertisement
ngnhtrg

Untitled

Oct 2nd, 2023
1,001
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.73 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torchvision.models as models
  4.  
  5. class CustomModel(nn.Module):
  6.     def __init__(self, num_classes_digits):
  7.         super(CustomModel, self).__init__()
  8.  
  9.         # Load the pre-trained VGG-16 model
  10.         vgg16 = models.vgg16(pretrained=True)
  11.  
  12.         # Extract the feature layers from VGG-16
  13.         self.features = vgg16.features
  14.  
  15.         # Define custom fully connected layers for digit prediction
  16.         self.fc_layers = nn.Sequential(
  17.             nn.Linear(512 * 7 * 7, 4096),
  18.             nn.ReLU(inplace=True),
  19.             nn.Dropout(),
  20.             nn.Linear(4096, 4096),
  21.             nn.ReLU(inplace=True),
  22.             nn.Dropout()
  23.         )
  24.  
  25.         # Output layers for digit prediction (13 digits)
  26.         self.fc_digits = nn.ModuleList([
  27.             nn.Linear(4096, num_classes_digits) for _ in range(13)
  28.         ])
  29.  
  30.     def forward(self, x):
  31.         x = self.features(x)
  32.         x = x.view(x.size(0), -1)  # Flatten the feature tensor
  33.         x = self.fc_layers(x)
  34.  
  35.         # Output predictions for digits
  36.         digit_logits = [fc(x) for fc in self.fc_digits]
  37.  
  38.         return digit_logits
  39.  
  40.     @staticmethod
  41.     def loss(digits_logits, digits_labels):
  42.         criterion = nn.CrossEntropyLoss()
  43.         digit_losses = [criterion(digits_logits[i], digits_labels[:, i]) for i in range(13)]  # Update to 13 digits
  44.         total_loss = sum(digit_losses)
  45.         return total_loss
  46.  
  47. # Create an instance of the custom model
  48. num_classes_digits = 10  # Define the number of classes for digit prediction
  49. model = CustomModel(num_classes_digits)
  50.  
  51. # Optionally, you can load the pre-trained weights for your custom model
  52. # model.load_state_dict(torch.load('custom_model_weights.pth'))
  53.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement