Not a member of Pastebin yet? Sign up — it unlocks many cool features!
- import torch
- import torch.nn as nn
- import torchvision
- import torchvision.transforms as transforms
- from torchvision.utils import save_image
- import os
# Device configuration: run on GPU when one is available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_dim = 100        # size of the generator's input noise vector
num_epochs = 100        # full passes over the training set
batch_size = 64
sample_dir = 'samples'  # where per-epoch generated-image grids are written
os.makedirs(sample_dir, exist_ok=True)

# MNIST dataset: ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then
# maps them to [-1, 1], matching the generator's Tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
# NOTE: download=True fetches MNIST into data/ on first run (network I/O).
mnist = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)
- # Generator model
- class Generator(nn.Module):
- def __init__(self):
- super(Generator, self).__init__()
- self.model = nn.Sequential(
- nn.Linear(latent_dim, 256),
- nn.ReLU(),
- nn.Linear(256, 784),
- nn.Tanh()
- )
- def forward(self, z):
- return self.model(z)
- # Discriminator model
- class Discriminator(nn.Module):
- def __init__(self):
- super(Discriminator, self).__init__()
- self.model = nn.Sequential(
- nn.Linear(784, 256),
- nn.LeakyReLU(0.2),
- nn.Linear(256, 1),
- nn.Sigmoid()
- )
- def forward(self, x):
- return self.model(x)
# Instantiate the two adversarial networks on the target device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Binary cross-entropy drives the real/fake classification game; each
# network gets its own Adam optimizer so the two can be stepped separately.
criterion = nn.BCELoss()
learning_rate = 0.0002
optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)
# Training loop: alternate one discriminator step and one generator step
# per mini-batch (standard GAN training schedule).
total_steps = len(data_loader)
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(data_loader):
        # Flatten each 28x28 image to a 784-vector for the MLP networks.
        real_images = real_images.view(-1, 784).to(device)
        # Use the actual batch size (the last batch may be smaller);
        # a distinct name avoids clobbering the batch_size hyperparameter.
        current_batch = real_images.shape[0]

        # Target labels: 1 for real images, 0 for generated ones.
        real_labels = torch.ones(current_batch, 1).to(device)
        fake_labels = torch.zeros(current_batch, 1).to(device)

        # ---- Train the discriminator ----
        # Loss on real images.
        outputs = discriminator(real_images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # Loss on freshly generated fake images. detach() stops
        # d_loss.backward() from needlessly backpropagating through
        # (and accumulating gradients on) the generator.
        z = torch.randn(current_batch, latent_dim).to(device)
        fake_images = generator(z)
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # ---- Train the generator ----
        # Fresh noise; the generator is rewarded when the discriminator
        # labels its output as real (hence real_labels as the target).
        z = torch.randn(current_batch, latent_dim).to(device)
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_steps}], "
                  f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}, "
                  f"D(x): {real_score.mean().item():.2f}, D(G(z)): {fake_score.mean().item():.2f}")

    # Save a grid of generated samples once per epoch (reuses the last z).
    with torch.no_grad():
        fake_images = generator(z).reshape(-1, 1, 28, 28)
        save_image(fake_images, os.path.join(sample_dir, f'fake_images_{epoch+1}.png'))
# Persist the trained weights for later sampling or fine-tuning.
for model, ckpt_path in ((generator, 'generator.ckpt'),
                         (discriminator, 'discriminator.ckpt')):
    torch.save(model.state_dict(), ckpt_path)
Advertisement
Add Comment
Please sign in to add a comment.