Advertisement
Guest User

Untitled

a guest
Mar 8th, 2021
44
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
Python | 3.08 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import numpy as np
  5.  
  6.  
  7. class Discriminator(nn.Module):
  8.     def __init__(self, data_dimension, hidden_size, dropout):
  9.         super(Discriminator, self).__init__()
  10.        
  11.         self.network = nn.Sequential()
  12.  
  13.         for layer, layer_output in enumerate(hidden_size):
  14.             if layer == 0:
  15.                 layer_input = data_dimension
  16.             else:
  17.                 layer_input = hidden_size[layer-1]          
  18.            
  19.             self.network.add_module("affine_{}".format(layer+1), nn.Linear(layer_input, layer_output))
  20.             self.network.add_module("dropout_{}".format(layer+1), nn.Dropout(p=dropout))
  21.             self.network.add_module("relu_{}".format(layer+1), nn.LeakyReLU(0.2))
  22.        
  23.         self.network.add_module("affine_{}".format(len(hidden_size)+1), nn.Linear(hidden_size[-1], 1))
  24.         self.network.add_module("output", nn.Sigmoid())
  25.    
  26.  
  27.     def forward(self, x):
  28.         return self.network(x)        
  29.  
  30.  
  31. class Generator(nn.Module):
  32.     def __init__(self, data_dimension, latent_dimension, hidden_size):
  33.         super(Generator, self).__init__()
  34.        
  35.         self.network = nn.Sequential()        
  36.  
  37.         for layer, layer_output in enumerate(hidden_size):
  38.             if layer == 0:
  39.                 layer_input = latent_dimension
  40.             else:
  41.                 layer_input = hidden_size[layer-1]
  42.            
  43.             self.network.add_module("affine_{}".format(layer+1), nn.Linear(layer_input, layer_output))
  44.             self.network.add_module("relu_{}".format(layer+1), nn.LeakyReLU(0.2))
  45.        
  46.         self.network.add_module("affine_{}".format(len(hidden_size)+1), nn.Linear(hidden_size[-1], data_dimension))
  47.         self.network.add_module("output", nn.Tanh())
  48.  
  49.  
  50.     def forward(self, x):
  51.         return self.network(x)
  52.  
  53.  
  54. class GAN(nn.Module):
  55.     def __init__(self, data_dimension, latent_dimension, discriminator_size, generator_size, dropout=0.5):
  56.         super(GAN, self).__init__()
  57.  
  58.         self.data_dimension = data_dimension
  59.  
  60.         if latent_dimension is None:
  61.             self.latent_dimension = data_dimension
  62.         else:
  63.             self.latent_dimension = latent_dimension
  64.        
  65.         self.discriminator = Discriminator(self.data_dimension, discriminator_size, dropout=dropout)
  66.         self.generator = Generator(self.data_dimension, self.latent_dimension, generator_size)
  67.    
  68.  
  69.     def forward(self, x, train_generator=False):
  70.         x = torch.reshape(x, (x.shape[0], np.product(x.shape[1:])))
  71.  
  72.         if train_generator == False:
  73.             return self.discriminator(x)
  74.         else:
  75.             generated_data = self.generator(x)
  76.             discr_gene_data = self.discriminator(generated_data)
  77.  
  78.             return (generated_data, discr_gene_data)
  79.  
  80.  
  81.     def discriminate(self, x):
  82.         discr_out = self.discriminator(x)
  83.        
  84.         return np.where(discr_out >= 0.5, 1, 0)
  85.    
  86.  
  87.     def generate(self, size):
  88.         noise = torch.randn(size, self.latent_dimension)
  89.  
  90.         return self.generator(noise)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement