Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- from resnet import resnet50 # 修改后的
- import torch.nn as nn
- import torch.nn.functional as F
- # Rough sketch (pseudocode) of the model.
class FinalModel(nn.Module):
    """Classifier that fuses backbone image features with an auxiliary info vector.

    The image is passed through a backbone and average-pooled to one vector per
    sample. That vector is concatenated with a linear embedding of the extra
    information, and a final linear layer produces the class logits.

    Args:
        num_classes: Number of output logits.
        info_len: Length of the auxiliary information vector (``input[1]``).
        k_vector: Dimension the information vector is embedded into.
        backbone: Optional feature extractor. Defaults to the project's
            modified ``resnet50()``. It is assumed to return a 4-D feature map
            shaped (batch, feature_dim, H, W). TODO confirm against the
            modified resnet, which is not visible here.
        feature_dim: Number of channels produced by ``backbone``
            (2048 for ResNet-50).

    Raises:
        TypeError: If ``info_len`` or ``k_vector`` is not supplied.
    """

    def __init__(self, num_classes=1000, info_len=None, k_vector=None,
                 backbone=None, feature_dim=2048):
        # Required arguments cannot follow a defaulted one (SyntaxError in the
        # original). Defaulting them to None keeps the positional order while
        # still insisting that callers provide them.
        if info_len is None or k_vector is None:
            raise TypeError("FinalModel requires both info_len and k_vector")
        # nn.Module.__init__ must run on *this* instance before any submodule
        # is assigned.
        super().__init__()
        self.resnet = resnet50() if backbone is None else backbone
        # Equivalent to AvgPool2d(7, stride=1) on a 7x7 map, and also works
        # for other spatial sizes.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # info_len is the info length; k_vector is the embedding dimension.
        self.fc_info = nn.Linear(info_len, k_vector)
        # The classifier sees image features concatenated with the *embedded*
        # info, so its input width is feature_dim + k_vector.
        self.fc_out = nn.Linear(feature_dim + k_vector, num_classes)

    def forward(self, input):
        """Compute class logits.

        Args:
            input: Indexable pair. ``input[0]`` is the image batch fed to the
                backbone. ``input[1]`` is the info tensor, presumably shaped
                (batch, info_len) (verify against callers).

        Returns:
            Logits shaped (batch, num_classes).
        """
        x_image = self.resnet(input[0])
        x_image = self.avgpool(x_image)
        x_image = torch.flatten(x_image, 1)  # (batch, feature_dim)
        x_info = self.fc_info(input[1])      # (batch, k_vector)
        # Concatenate along the feature axis, not the batch axis.
        x = torch.cat([x_image, x_info], dim=1)
        return self.fc_out(x)

    # Backward-compatible alias for the original misspelled method name.
    forword = forward
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement