Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
- import torch
- import torch.nn as nn
- import torchvision.models as models
class CustomModel(nn.Module):
    """VGG-16-based multi-head classifier that predicts a fixed-length digit sequence.

    The convolutional backbone is VGG-16's ``features`` stack (ImageNet
    pre-trained by default). Its output is pooled to a fixed 7x7 grid,
    flattened, passed through two freshly initialized 4096-unit fully
    connected layers, and then fed to ``num_digits`` independent linear
    heads, one per digit position.

    Args:
        num_classes_digits: Number of classes each digit head predicts.
        num_digits: Number of digit positions (output heads). Defaults to 13,
            the value the model previously hard-coded.
        pretrained: Whether to load pre-trained weights into the VGG-16
            backbone. Defaults to True, matching the previous behavior.
    """

    def __init__(self, num_classes_digits, num_digits=13, pretrained=True):
        super().__init__()
        vgg16 = self._load_vgg16(pretrained)
        # Keep only the convolutional feature extractor; VGG-16's own
        # classifier is discarded and replaced by the layers below.
        self.features = vgg16.features
        # VGG-16's feature stack yields 512 x 7 x 7 only for 224x224 inputs.
        # Adaptive pooling (as torchvision's VGG itself does) pins the spatial
        # size so the first Linear layer also accepts other resolutions. It has
        # no parameters, so previously saved state_dicts still load, and it is
        # an identity for 224x224 inputs.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Shared fully connected trunk. These layers are newly constructed,
        # i.e. randomly initialized, not copied from the pre-trained classifier.
        self.fc_layers = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
        )
        # One independent classification head per digit position.
        self.fc_digits = nn.ModuleList(
            [nn.Linear(4096, num_classes_digits) for _ in range(num_digits)]
        )

    @staticmethod
    def _load_vgg16(pretrained):
        """Build torchvision's VGG-16, preferring the ``weights`` API when available."""
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is None:
            # Older torchvision releases only understand the ``pretrained`` flag.
            return models.vgg16(pretrained=pretrained)
        # Newer torchvision deprecates ``pretrained`` in favor of ``weights``;
        # DEFAULT selects the same ImageNet weights as ``pretrained=True``.
        return models.vgg16(weights=weights_enum.DEFAULT if pretrained else None)

    def forward(self, x):
        """Return per-digit logits for a batch of images.

        Args:
            x: Image batch; presumably ``(batch, 3, H, W)`` as VGG-16 expects.
                TODO: confirm against the data pipeline.

        Returns:
            A list with one ``(batch, num_classes_digits)`` logits tensor per
            digit head, in head order.
        """
        x = self.features(x)
        x = self.avgpool(x)
        # Flatten everything except the batch dimension.
        x = torch.flatten(x, 1)
        x = self.fc_layers(x)
        return [head(x) for head in self.fc_digits]

    @staticmethod
    def loss(digits_logits, digits_labels):
        """Sum the cross-entropy losses of all digit heads.

        Args:
            digits_logits: Sequence of logits tensors, one per digit head, as
                returned by ``forward``.
            digits_labels: Integer class-index tensor whose column ``i`` holds
                the targets for head ``i``. Presumably shaped
                ``(batch, num_digits)``; TODO confirm.

        Returns:
            Scalar tensor: the sum over heads of ``nn.CrossEntropyLoss``
            (default mean reduction) for each head.
        """
        criterion = nn.CrossEntropyLoss()
        # Iterate over the heads actually supplied rather than a hard-coded
        # count, so models built with any ``num_digits`` are supported.
        digit_losses = [
            criterion(logits, digits_labels[:, i])
            for i, logits in enumerate(digits_logits)
        ]
        return sum(digit_losses)
# Build the model. num_classes_digits is the number of classes each digit
# head predicts (presumably the decimal digits 0-9; confirm against the
# label encoding).
num_classes_digits = 10
model = CustomModel(num_classes_digits=num_classes_digits)
# To restore weights saved from an earlier run, uncomment the line below:
# model.load_state_dict(torch.load('custom_model_weights.pth'))
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement