Advertisement
Guest User

Untitled

a guest
Jun 17th, 2019
102
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.71 KB | None | 0 0
  1. class EdnetDown(nn.Module):
  2.     def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
  3.         super(EdnetDown, self).__init__()
  4.         layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
  5.         if normalize:
  6.             layers.append(nn.InstanceNorm2d(out_size))
  7.         layers.append(nn.LeakyReLU(0.2))
  8.         if dropout:
  9.             layers.append(nn.Dropout(dropout))
  10.         self.model = nn.Sequential(*layers)
  11.  
  12.     def forward(self, x):
  13.         return self.model(x)
  14.  
  15.  
  16. class EdnetUp(nn.Module):
  17.     def __init__(self, in_size, out_size, dropout=0.0):
  18.         super(EdnetUp, self).__init__()
  19.         layers = [
  20.             nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
  21.             nn.InstanceNorm2d(out_size),
  22.             nn.ReLU(inplace=True),
  23.         ]
  24.         if dropout:
  25.             layers.append(nn.Dropout(dropout))
  26.  
  27.         self.model = nn.Sequential(*layers)
  28.         weights_init_normal(self.model)
  29.  
  30.     def forward(self, x, skip_input):
  31.         x = self.model(x)
  32.         x = torch.cat((x, skip_input), 1)
  33.  
  34.         return x
  35.  
  36.  
  37. class Ednet(nn.Module):
  38.     def __init__(self, n_channels=3, n_filters=32):
  39.         super(Ednet, self).__init__()
  40.  
  41.         self.down1 = EdnetDown(n_channels, 64, normalize=False)
  42.         self.down2 = EdnetDown(64, 128)
  43.         self.down3 = EdnetDown(128, 256)
  44.         self.down4 = EdnetDown(256, 512, dropout=0.5)
  45.         self.down5 = EdnetDown(512, 512, dropout=0.5)
  46.         self.down6 = EdnetDown(512, 512, dropout=0.5)
  47.         self.down7 = EdnetDown(512, 512, dropout=0.5)
  48.         self.down8 = EdnetDown(512, 512, normalize=False, dropout=0.5)
  49.  
  50.         self.up1 = EdnetUp(512, 512, dropout=0.5)
  51.         self.up2 = EdnetUp(1024, 512, dropout=0.5)
  52.         self.up3 = EdnetUp(1024, 512, dropout=0.5)
  53.         self.up4 = EdnetUp(1024, 512, dropout=0.5)
  54.         self.up5 = EdnetUp(1024, 256)
  55.         self.up6 = EdnetUp(512, 128)
  56.         self.up7 = EdnetUp(256, 64)
  57.  
  58.         self.final = nn.Sequential(
  59.             nn.Upsample(scale_factor=2),
  60.             nn.ZeroPad2d((1, 0, 1, 0)),
  61.             nn.Conv2d(128, 3, 4, padding=1),
  62.             nn.Tanh(),
  63.         )
  64.  
  65.     def forward(self, x):
  66.         # U-Net generator with skip connections from encoder to decoder
  67.         d1 = self.down1(x)
  68.         d2 = self.down2(d1)
  69.         d3 = self.down3(d2)
  70.         d4 = self.down4(d3)
  71.         d5 = self.down5(d4)
  72.         d6 = self.down6(d5)
  73.         d7 = self.down7(d6)
  74.         d8 = self.down8(d7)
  75.         u1 = self.up1(d8, d7)
  76.         u2 = self.up2(u1, d6)
  77.         u3 = self.up3(u2, d5)
  78.         u4 = self.up4(u3, d4)
  79.         u5 = self.up5(u4, d3)
  80.         u6 = self.up6(u5, d2)
  81.         u7 = self.up7(u6, d1)
  82.  
  83.         return self.final(u7)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement