Advertisement
Guest User

Приложение 2

a guest
Apr 23rd, 2019
107
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.82 KB | None | 0 0
  1. from torch import nn
  2. import torch
  3.  
  4. import numpy as np
  5. from tqdm import tqdm_notebook as tqdm
  6. from torch import optim
  7. from torch.optim import lr_scheduler
  8.  
  9. import pandas as pd
  10. from torch.utils.data import DataLoader, Dataset
  11. import torchvision.transforms as trs
  12. import matplotlib.pyplot as plt
  13.  
  14. # Обычная комбинация слоев: Сверточный, Батч норм, Релу
  15. def conv_bn(inp, oup, stride = 1):
  16.     return nn.Sequential(
  17.         nn.Conv2d(inp, oup, 3, stride, padding = 1, bias=False),
  18.         nn.BatchNorm2d(oup),
  19.         nn.ReLU(inplace=True)
  20.                         )
  21.  
  22. # Слой развёртки генома
  23. class Flatten(nn.Module):
  24.     def __init__(self, oup):
  25.         self.oup = oup
  26.  
  27.     def __call__(self, sample):
  28.         return sample.view(sample.shape[0], 1, self.oup, -1)
  29.        
  30.  
  31. # Геномный слой по аналогии с basset
  32. def genome_conv(previous_oup = 3, oup = 3):
  33.     return nn.Sequential(
  34.                                 nn.Conv2d(1, oup, kernel_size = (previous_oup, 3)),
  35.                                 nn.ReLU(inplace = True),
  36.                                 Flatten(oup=oup),
  37.                                 nn.MaxPool2d(kernel_size= (1, 2))
  38.                         )
  39.  
  40. # Если параметр mode = 'Genome' после первого слоя дальше идут "Геномные слои" genome_conv, если mode = Traditional, то conv_bn
  41. class GenomeNet(nn.Module):
  42.     def __init__(self, mode = 'Genome'):
  43.         super(GenomeNet, self).__init__()
  44.        
  45.         self.mode = mode
  46.        
  47.         self.conv = nn.Conv2d(1, 3, kernel_size = (4,11))
  48.         self.relu = nn.ReLU(inplace = True)
  49.         self.pooling_initial = nn.MaxPool2d(kernel_size= (1, 4))
  50.         self.pooling_afterwards = nn.MaxPool2d(kernel_size= (1, 2))
  51.        
  52.        
  53.         if mode == 'Traditional':
  54.         ### option 1
  55.             self.ConvBnRelu1 = conv_bn(1, 32, stride = 1)
  56.             self.ConvBnRelu2 = conv_bn(32, 64, stride = 1)
  57.             self.ConvBnRelu3 = conv_bn(64, 128, stride = 1)
  58.             self.ConvBnRelu4 = conv_bn(128, 260, stride = 1)
  59.             self.GAP = nn.AdaptiveAvgPool2d(1)
  60.  
  61.         elif mode == 'Genome':
  62.         ### option 2
  63.             self.GenomeConv1 = genome_conv(previous_oup = 3, oup = 8)
  64.             self.GenomeConv2 = genome_conv(previous_oup = 8, oup = 16)
  65.             self.GenomeConv3 = genome_conv(previous_oup = 16, oup = 16)
  66.             self.GenomeConv4 = genome_conv(previous_oup = 16, oup = 32)
  67.             self.GenomeConv5 = genome_conv(previous_oup = 32, oup = 52)
  68.        
  69.        
  70.         # Полносвязные слои
  71.         self.FC1 = nn.Linear(260, 164)
  72.         self.FC2 = nn.Linear(164, 164)
  73.        
  74.         self.FinalFC = nn.Linear(164, 1)
  75.         self.Sigmoid = nn.Sigmoid()
  76.        
  77.     def forward(self, x):
  78.        
  79.         x = self.conv(x).reshape(x.shape[0], 1, 3, -1)
  80.        
  81.         x = self.relu(x)
  82.         x = self.pooling_initial(x) #100 x 1 x 3 x 122
  83.        
  84.         if self.mode == 'Traditional':
  85.             x = self.pooling_afterwards(self.ConvBnRelu1(x))
  86.             x = self.pooling_afterwards(self.ConvBnRelu2(x))
  87.             x = self.pooling_afterwards(self.ConvBnRelu3(x))
  88.             x = self.pooling_afterwards(self.ConvBnRelu4(x))
  89.            
  90.             x = self.GAP(x).reshape(x.shape[0], -1)
  91.                
  92.         elif self.mode == 'Genome':
  93.             x = self.GenomeConv1(x)
  94.             x = self.GenomeConv2(x)
  95.             x = self.GenomeConv3(x)
  96.             x = self.GenomeConv4(x)
  97.             x = self.GenomeConv5(x)
  98.            
  99.             x = x.reshape(x.shape[0], -1)
  100.            
  101.         x = self.relu(self.FC1(x))
  102.         x = self.relu(self.FC2(x))
  103.        
  104.         x = self.Sigmoid(self.FinalFC(x))
  105.        
  106.         return x
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement