Advertisement
ChrisDunamis

Untitled

May 1st, 2024
496
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.66 KB | None | 0 0
  1. import torch
  2.  
  3. from torch import nn
  4.  
  5. print("torch Version:", torch.__version__)
  6.  
  7.  
  8. class ConvolutionalBlock(nn.Module):
  9.     def __init__(self, input_channels, output_channels, discriminator=False, use_activation=True,
  10.                  use_batch_normalization=True, **kwargs):
  11.  
  12.         super().__init__()
  13.  
  14.         self.use_activation = use_activation
  15.         self.cnn = nn.Conv2d(input_channels, output_channels, **kwargs, bias=not use_batch_normalization)
  16.         self.use_batch_normalization = nn.BatchNorm2d(output_channels) if use_batch_normalization else nn.Identity()
  17.         self.use_activation = (
  18.             nn.LeakyReLU(0.2, inplace=True) if discriminator else nn.PReLU(num_parameters=output_channels)
  19.         )
  20.  
  21.     def forward(self, x):
  22.         return self.use_activation(self.use_batch_normalization(self.cnn(x)) if self.use_activation else
  23.                                    self.use_batch_normalization(self.cnn(x)))
  24.    
  25.    
  26. class UpsampleBlock(nn.Module):
  27.     def __init__(self, input_channels, scale_factor):
  28.  
  29.         super().__init__()
  30.  
  31.         self.convolutional_layer =  nn.Conv2d(input_channels, (input_channels * scale_factor) ** 2, 2, 3, 1, 1)
  32.         # Where, Input Channels * 4, Height, Width --> Input Channels, Height * 2, Width * 2:
  33.         self.pixel_shuffle = nn.PixelShuffle(scale_factor)
  34.         self.activation = nn.PReLU(num_parameters=input_channels)
  35.    
  36.     def forward(self, x):
  37.         return self.activate(self.pixel_shuffle(self.convolutional_layer(x)))
  38.    
  39.    
  40. class ResidualBlock(nn.Module):
  41.     def __init__(self, input_channels):
  42.  
  43.         super().__init__()
  44.  
  45.         self.block1 = ConvolutionalBlock(input_channels, input_channels, kernel_size=3, stride=1, padding=1)
  46.         self.block2 = ConvolutionalBlock(input_channels, input_channels, kernel_size=3, stride=1, padding=1,
  47.                                          use_activation=False)
  48.    
  49.     def forward(self, x):
  50.         output = self.block1(x)
  51.         output = self.block2(output)
  52.  
  53.         return output + x
  54.    
  55.    
  56. class Generator(nn.Module):
  57.     def __init__(self, input_channels=3, number_of_channels=64, number_of_blocks=16):
  58.  
  59.         super().__init__()
  60.  
  61.         self.initialize = ConvolutionalBlock(input_channels, number_of_channels, kernel_size=9, stride=1, padding=4,
  62.                                              use_batch_normalization=False)
  63.         self.residual_layers = nn.Sequential(*[ResidualBlock(number_of_channels) for _ in range(number_of_blocks)])
  64.         self.convolutional_block = ConvolutionalBlock(input_channels, number_of_channels, kernel_size=9, stride=1,
  65.                                                       padding=4, use_activation=False)
  66.         self.upsamples = nn.Sequential(UpsampleBlock(number_of_channels, 2), UpsampleBlock(number_of_channels, 2))
  67.         self.final_layer = nn.Conv2d(number_of_channels, input_channels, kernel_size=9, stride=1, padding=4)
  68.    
  69.     def forward(self, x):
  70.         initialise = self.initialize(x)
  71.         x = self.residual_layers(initialise)
  72.         x = self.convolutional_block(x) + initialise
  73.         x = self.upsamples(x)
  74.  
  75.         return torch.tanh(self.final_layer(x))
  76.    
  77.    
  78. class Discriminator(nn.Module):
  79.     def __init__(self, input_channels=3, features=[64, 64, 128, 128, 256, 256, 512, 512]):
  80.        
  81.         super().__init__()
  82.  
  83.         blocks = []
  84.        
  85.         for idx, feature in enumerate(features):
  86.             blocks.append(
  87.                 ConvolutionalBlock(
  88.                     input_channels, feature, kernel_size=3, stride=1 + (idx % 2), padding=1, discriminator=True,
  89.                     use_activation=True, use_batch_normalization=False if idx == 0 else True
  90.                 )
  91.             )
  92.             input_channels = feature
  93.  
  94.         self.blocks = nn.Sequential(*blocks)
  95.         self.classifier = nn.Sequential(
  96.             nn.AdaptiveAvgPool2d((6, 6)),
  97.             nn.Flatten(),
  98.             nn.Linear((512 * 6 * 6), 1024),
  99.             nn.LeakyReLU(0.2, inplace=True),
  100.             nn.Linear(1024 * 1)
  101.         )
  102.    
  103.     def forward(self, x):
  104.         x = self.blocks(x)
  105.  
  106.         return self.classifier(x)
  107.    
  108.    
  109. def test():
  110.     low_resolution = 24
  111.  
  112.     with torch.cuda.amp.autocast():
  113.         x = torch.randn(5, 3, low_resolution, low_resolution)
  114.         generator = Generator()
  115.         generator_output = generator(x)  # <--- The error points here
  116.         discriminator = Discriminator()
  117.         discriminator_output = discriminator(generator_output)
  118.  
  119.         print("Generator Output Shape:     {}".format(generator_output.shape) + "\n" +
  120.               "Discriminator Output Shape: {}".format(discriminator_output.shape))
  121.  
  122.  
  123. test()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement