import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Discriminator(nn.Module):
    """MLP discriminator: scores a flattened sample as real (1) or fake (0)."""

    def __init__(self, data_dimension, hidden_size, dropout):
        super(Discriminator, self).__init__()
        self.network = nn.Sequential()
        # One Linear -> Dropout -> LeakyReLU block per entry in hidden_size.
        for layer, layer_output in enumerate(hidden_size):
            if layer == 0:
                layer_input = data_dimension
            else:
                layer_input = hidden_size[layer - 1]
            self.network.add_module("affine_{}".format(layer + 1), nn.Linear(layer_input, layer_output))
            self.network.add_module("dropout_{}".format(layer + 1), nn.Dropout(p=dropout))
            self.network.add_module("relu_{}".format(layer + 1), nn.LeakyReLU(0.2))
        # Final affine layer to a single logit, mapped to a probability by a sigmoid.
        self.network.add_module("affine_{}".format(len(hidden_size) + 1), nn.Linear(hidden_size[-1], 1))
        self.network.add_module("output", nn.Sigmoid())

    def forward(self, x):
        return self.network(x)


class Generator(nn.Module):
    """MLP generator: maps a latent noise vector to a synthetic sample in [-1, 1]."""

    def __init__(self, data_dimension, latent_dimension, hidden_size):
        super(Generator, self).__init__()
        self.network = nn.Sequential()
        # One Linear -> LeakyReLU block per entry in hidden_size.
        for layer, layer_output in enumerate(hidden_size):
            if layer == 0:
                layer_input = latent_dimension
            else:
                layer_input = hidden_size[layer - 1]
            self.network.add_module("affine_{}".format(layer + 1), nn.Linear(layer_input, layer_output))
            self.network.add_module("relu_{}".format(layer + 1), nn.LeakyReLU(0.2))
        # Final affine layer back to the data dimension, squashed to [-1, 1] by tanh.
        self.network.add_module("affine_{}".format(len(hidden_size) + 1), nn.Linear(hidden_size[-1], data_dimension))
        self.network.add_module("output", nn.Tanh())

    def forward(self, x):
        return self.network(x)


class GAN(nn.Module):
    """Bundles a Generator and a Discriminator behind a single forward interface."""

    def __init__(self, data_dimension, latent_dimension, discriminator_size, generator_size, dropout=0.5):
        super(GAN, self).__init__()
        self.data_dimension = data_dimension
        # Default the latent dimension to the data dimension when none is given.
        if latent_dimension is None:
            self.latent_dimension = data_dimension
        else:
            self.latent_dimension = latent_dimension
        self.discriminator = Discriminator(self.data_dimension, discriminator_size, dropout=dropout)
        self.generator = Generator(self.data_dimension, self.latent_dimension, generator_size)

    def forward(self, x, train_generator=False):
        # Flatten everything after the batch dimension.
        x = torch.reshape(x, (x.shape[0], int(np.prod(x.shape[1:]))))
        if not train_generator:
            # Discriminator pass: x is a batch of (real or pre-generated) data.
            return self.discriminator(x)
        else:
            # Generator pass: x is latent noise; also score the samples with the discriminator.
            generated_data = self.generator(x)
            discr_gene_data = self.discriminator(generated_data)
            return (generated_data, discr_gene_data)

    def discriminate(self, x):
        # Hard real/fake decision: threshold the discriminator output at 0.5.
        discr_out = self.discriminator(x)
        return np.where(discr_out.detach().cpu().numpy() >= 0.5, 1, 0)

    def generate(self, size):
        # Sample `size` latent vectors and decode them into synthetic data.
        noise = torch.randn(size, self.latent_dimension)
        return self.generator(noise)
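

# A minimal usage sketch: one adversarial training step. The 784-dimensional
# data scaled to [-1, 1], the hidden sizes, batch size, Adam optimizers, and
# BCE loss below are illustrative assumptions, not taken from the code above.
if __name__ == "__main__":
    gan = GAN(data_dimension=784, latent_dimension=100,
              discriminator_size=[512, 256], generator_size=[256, 512])
    d_optim = torch.optim.Adam(gan.discriminator.parameters(), lr=2e-4)
    g_optim = torch.optim.Adam(gan.generator.parameters(), lr=2e-4)
    criterion = nn.BCELoss()

    real_batch = torch.rand(64, 784) * 2 - 1   # stand-in for a real data batch in [-1, 1]
    real_labels = torch.ones(64, 1)
    fake_labels = torch.zeros(64, 1)

    # Discriminator step: score real data and detached generated data.
    d_optim.zero_grad()
    d_real = gan(real_batch)                   # train_generator=False
    d_fake = gan(gan.generate(64).detach())
    d_loss = criterion(d_real, real_labels) + criterion(d_fake, fake_labels)
    d_loss.backward()
    d_optim.step()

    # Generator step: push discriminator scores on generated samples toward "real".
    g_optim.zero_grad()
    noise = torch.randn(64, gan.latent_dimension)
    _, d_on_generated = gan(noise, train_generator=True)
    g_loss = criterion(d_on_generated, real_labels)
    g_loss.backward()
    g_optim.step()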