brian_dot_casa

Image generator example: a minimal fully connected GAN trained on MNIST

Aug 25th, 2023
139
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.65 KB | Source Code | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torchvision
  4. import torchvision.transforms as transforms
  5. from torchvision.utils import save_image
  6. import os
  7.  
  8. # Device configuration
  9. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  10.  
  11. # Hyperparameters
  12. latent_dim = 100
  13. num_epochs = 100
  14. batch_size = 64
  15. sample_dir = 'samples'
  16. os.makedirs(sample_dir, exist_ok=True)
  17.  
  18. # MNIST dataset
  19. transform = transforms.Compose([
  20.     transforms.ToTensor(),
  21.     transforms.Normalize(mean=(0.5,), std=(0.5,))
  22. ])
  23. mnist = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True)
  24. data_loader = torch.utils.data.DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)
  25.  
  26. # Generator model
  27. class Generator(nn.Module):
  28.     def __init__(self):
  29.         super(Generator, self).__init__()
  30.         self.model = nn.Sequential(
  31.             nn.Linear(latent_dim, 256),
  32.             nn.ReLU(),
  33.             nn.Linear(256, 784),
  34.             nn.Tanh()
  35.         )
  36.  
  37.     def forward(self, z):
  38.         return self.model(z)
  39.  
  40. # Discriminator model
  41. class Discriminator(nn.Module):
  42.     def __init__(self):
  43.         super(Discriminator, self).__init__()
  44.         self.model = nn.Sequential(
  45.             nn.Linear(784, 256),
  46.             nn.LeakyReLU(0.2),
  47.             nn.Linear(256, 1),
  48.             nn.Sigmoid()
  49.         )
  50.  
  51.     def forward(self, x):
  52.         return self.model(x)
  53.  
  54. # Create models
  55. generator = Generator().to(device)
  56. discriminator = Discriminator().to(device)
  57.  
  58. # Loss and optimizers
  59. criterion = nn.BCELoss()
  60. optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002)
  61. optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
  62.  
  63. # Training loop
  64. total_steps = len(data_loader)
  65. for epoch in range(num_epochs):
  66.     for i, (real_images, _) in enumerate(data_loader):
  67.         real_images = real_images.view(-1, 784).to(device)
  68.         batch_size = real_images.shape[0]
  69.        
  70.         # Train discriminator
  71.         real_labels = torch.ones(batch_size, 1).to(device)
  72.         fake_labels = torch.zeros(batch_size, 1).to(device)
  73.        
  74.         # Discriminator loss on real images
  75.         outputs = discriminator(real_images)
  76.         d_loss_real = criterion(outputs, real_labels)
  77.         real_score = outputs
  78.        
  79.         # Generate fake images and compute discriminator loss on fake images
  80.         z = torch.randn(batch_size, latent_dim).to(device)
  81.         fake_images = generator(z)
  82.         outputs = discriminator(fake_images)
  83.         d_loss_fake = criterion(outputs, fake_labels)
  84.         fake_score = outputs
  85.        
  86.         d_loss = d_loss_real + d_loss_fake
  87.         optimizer_D.zero_grad()
  88.         d_loss.backward()
  89.         optimizer_D.step()
  90.        
  91.         # Train generator
  92.         z = torch.randn(batch_size, latent_dim).to(device)
  93.         fake_images = generator(z)
  94.         outputs = discriminator(fake_images)
  95.         g_loss = criterion(outputs, real_labels)
  96.         optimizer_G.zero_grad()
  97.         g_loss.backward()
  98.         optimizer_G.step()
  99.        
  100.         if (i+1) % 100 == 0:
  101.             print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_steps}], "
  102.                   f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}, "
  103.                   f"D(x): {real_score.mean().item():.2f}, D(G(z)): {fake_score.mean().item():.2f}")
  104.            
  105.     # Save generated images
  106.     with torch.no_grad():
  107.         fake_images = generator(z).reshape(-1, 1, 28, 28)
  108.         save_image(fake_images, os.path.join(sample_dir, f'fake_images_{epoch+1}.png'))
  109.  
  110. # Save models
  111. torch.save(generator.state_dict(), 'generator.ckpt')
  112. torch.save(discriminator.state_dict(), 'discriminator.ckpt')
  113.  
Advertisement
Add Comment
Please, Sign In to add comment