Advertisement
asurkis

Untitled

Jan 7th, 2021
647
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.69 KB | None | 0 0
  1. class MyGenerator(nn.Module):
  2.     def __init__(self):
  3.         super().__init__()
  4.         # [ 3, 64, 128, 256, 512, 512, 512, 512, 512 ]
  5.  
  6.         self.downsamplers = [                                                                                                                          # 256 x 256 x   3
  7.             nn.Sequential(nn.Conv2d(   3,  64, 4, stride=2, padding=1, padding_mode='replicate', bias=False ),                        nn.LeakyReLU()), # 128 x 128 x  64
  8.             nn.Sequential(nn.Conv2d(  64, 128, 4, stride=2, padding=1, padding_mode='replicate', bias=False ), nn.BatchNorm2d( 128 ), nn.LeakyReLU()), #  64 x  64 x 128
  9.             nn.Sequential(nn.Conv2d( 128, 256, 4, stride=2, padding=1, padding_mode='replicate', bias=False ), nn.BatchNorm2d( 256 ), nn.LeakyReLU()), #  32 x  32 x 256
  10.             nn.Sequential(nn.Conv2d( 256, 512, 4, stride=2, padding=1, padding_mode='replicate', bias=False ), nn.BatchNorm2d( 512 ), nn.LeakyReLU()), #  16 x  16 x 512
  11.             nn.Sequential(nn.Conv2d( 512, 512, 4, stride=2, padding=1, padding_mode='replicate', bias=False ), nn.BatchNorm2d( 512 ), nn.LeakyReLU()), #   8 x   8 x 512
  12.             nn.Sequential(nn.Conv2d( 512, 512, 4, stride=2, padding=1, padding_mode='replicate', bias=False ), nn.BatchNorm2d( 512 ), nn.LeakyReLU()), #   4 x   4 x 512
  13.             nn.Sequential(nn.Conv2d( 512, 512, 4, stride=2, padding=1, padding_mode='replicate', bias=False ), nn.BatchNorm2d( 512 ), nn.LeakyReLU()), #   2 x   2 x 512
  14.             nn.Sequential(nn.Conv2d( 512, 512, 4, stride=2, padding=1, padding_mode='replicate', bias=False ), nn.BatchNorm2d( 512 ), nn.LeakyReLU()), #   1 x   1 x 512
  15.         ]
  16.  
  17.         self.downsampler1 = self.downsamplers[0]
  18.         self.downsampler2 = self.downsamplers[1]
  19.         self.downsampler3 = self.downsamplers[2]
  20.         self.downsampler4 = self.downsamplers[3]
  21.         self.downsampler5 = self.downsamplers[4]
  22.         self.downsampler6 = self.downsamplers[5]
  23.         self.downsampler7 = self.downsamplers[6]
  24.         self.downsampler8 = self.downsamplers[7]
  25.  
  26.         self.upsamplers = [                                                                                                         #                             1 x   1 x  512
  27.             nn.Sequential(nn.ConvTranspose2d(  512, 512, 4, stride=2, padding=1 ), nn.BatchNorm2d( 512 ), nn.Dropout(), nn.ReLU()), #   2 x   2 x (512 + 512) =   2 x   2 x 1024
  28.             nn.Sequential(nn.ConvTranspose2d( 1024, 512, 4, stride=2, padding=1 ), nn.BatchNorm2d( 512 ), nn.Dropout(), nn.ReLU()), #   4 x   4 x (512 + 512) =   4 x   4 x 1024
  29.             nn.Sequential(nn.ConvTranspose2d( 1024, 512, 4, stride=2, padding=1 ), nn.BatchNorm2d( 512 ), nn.Dropout(), nn.ReLU()), #   8 x   8 x (512 + 512) =   8 x   8 x 1024
  30.             nn.Sequential(nn.ConvTranspose2d( 1024, 512, 4, stride=2, padding=1 ), nn.BatchNorm2d( 512 ),               nn.ReLU()), #  16 x  16 x (512 + 512) =  16 x  16 x 1024
  31.             nn.Sequential(nn.ConvTranspose2d( 1024, 256, 4, stride=2, padding=1 ), nn.BatchNorm2d( 256 ),               nn.ReLU()), #  32 x  32 x (256 + 256) =  32 x  32 x  512
  32.             nn.Sequential(nn.ConvTranspose2d(  512, 128, 4, stride=2, padding=1 ), nn.BatchNorm2d( 128 ),               nn.ReLU()), #  64 x  64 x (128 + 128) =  64 x  64 x  256
  33.             nn.Sequential(nn.ConvTranspose2d(  256,  64, 4, stride=2, padding=1 ), nn.BatchNorm2d(  64 ),               nn.ReLU()), # 128 x 128 x ( 64 +  64) = 128 x 128 x  128
  34.             nn.Sequential(nn.ConvTranspose2d(  128,   3, 4, stride=2, padding=1 ), nn.BatchNorm2d(   3 ),               nn.ReLU()), #                           256 x 256 x    3
  35.         ]
  36.  
  37.         self.upsampler1 = self.upsamplers[0]
  38.         self.upsampler2 = self.upsamplers[1]
  39.         self.upsampler3 = self.upsamplers[2]
  40.         self.upsampler4 = self.upsamplers[3]
  41.         self.upsampler5 = self.upsamplers[4]
  42.         self.upsampler6 = self.upsamplers[5]
  43.         self.upsampler7 = self.upsamplers[6]
  44.         self.upsampler8 = self.upsamplers[7]
  45.  
  46.     def forward(self, x):
  47.         skips = []
  48.         t = x
  49.         for layer in self.downsamplers:
  50.             t = layer(t)
  51.             skips.append(t)
  52.  
  53.         t = self.upsamplers[0](skips.pop())
  54.        
  55.         for layer in self.upsamplers[1:]:
  56.             t = torch.cat((t, skips.pop()), dim=1)
  57.             t = layer(t)
  58.  
  59.         return t
  60.  
  61. class MyDiscriminator(nn.Module):
  62.     def __init__(self):
  63.         super().__init__()
  64.  
  65.         self.main = nn.Sequential(                                                                                                      # 256 x 256 x   6
  66.             nn.Conv2d(   6,  64, 4, stride=2, padding=1, padding_mode='replicate', bias=False ),                        nn.LeakyReLU(), # 128 x 128 x  64
  67.             nn.Conv2d(  64, 128, 4, stride=2, padding=1, padding_mode='replicate', bias=False ), nn.BatchNorm2d( 128 ), nn.LeakyReLU(), #  64 x  64 x 128
  68.             nn.Conv2d( 128, 256, 4, stride=2, padding=1, padding_mode='replicate', bias=False ), nn.BatchNorm2d( 256 ), nn.LeakyReLU(), #  32 x  32 x 256
  69.             nn.ZeroPad2d(1),                                                                                                            #  34 x  34 x 256
  70.             nn.Conv2d( 256, 512, 4, bias=False ), nn.BatchNorm2d(512), nn.LeakyReLU(),                                                  #  31 x  31 x 512
  71.             nn.ZeroPad2d(1),                                                                                                            #  33 x  33 x 512
  72.             nn.Conv2d( 512,   1, 4, bias=False ),                                                                                       #  30 x  30 x   1
  73.         )
  74.  
  75.     def forward(self, x):
  76.         return self.main(x)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement