Advertisement
Guest User

Untitled

a guest
Aug 21st, 2019
110
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.84 KB | None | 0 0
  1. import torch
  2. from resnet import resnet50 # 修改后的
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5.  
  6. # Pseudocode sketch
  7.  
  8. class FinalModel(nn.Module):
  9. def __init__(self, num_classes=1000, info_len, k_vector):
  10. super(FinalModel).__init__()
  11. self.resnet = resnet50()
  12. self.avgpool = nn.AvgPool2d(7, stride=1)
  13. self.fc_info = nn.Linear(info_len, k_vector) # info_len是信息长度,k_vector是想embed到多少维
  14. self.fc_out = nn.Linear(2048+info_len, num_classes)
  15.  
  16. def forword(self, input): # 假定input[0]是图像,input[1]是要加的信息
  17. x_image = self.resnet(input[0])
  18. x_image = self.avgpool(x_image)
  19. x_image = x_image.view(x_image.size(0), -1)
  20.  
  21. x_info = self.fc_info(input[1])
  22. x = torch.cat([x_image, x_info])
  23.  
  24. x = self.fc_out(x)
  25. return x
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement