import torch
import torch.nn as nn
import torch.optim as optim

# GAN, generate_noise, get_device, train_loader and batch_size are assumed to be
# defined elsewhere in the script/notebook this snippet was taken from.

mnist_gan = GAN(data_dimension=784, latent_dimension=64,
                discriminator_size=(64, 128, 256),
                generator_size=(128, 256, 512),
                dropout=0.3)

criterion = nn.BCELoss()

discr_optim = optim.Adam(mnist_gan.discriminator.parameters(), lr=1e-3, weight_decay=0)
gene_optim = optim.Adam(mnist_gan.generator.parameters(), lr=1e-3, weight_decay=0)

epochs = 50
k = 1  # number of discriminator passes over the data per epoch

device = get_device()
mnist_gan = mnist_gan.to(device)

for epoch in range(epochs):
    discr_epoch_loss = 0.0
    gene_epoch_loss = 0.0

    # --- discriminator update(s) ---
    for _ in range(k):
        for real_data, real_labels in train_loader:
            real_data = real_data.to(device)
            # real_labels are expected to be all-ones BCE targets for "real",
            # not the MNIST digit classes
            real_labels = real_labels.float().to(device)

            # generate_noise is expected to return noise in data space together
            # with all-zeros ("fake") BCE targets for the discriminator
            fake_data, fake_labels = generate_noise(batch_size, mnist_gan.data_dimension)
            fake_data = fake_data.to(device)
            fake_labels = fake_labels.to(device)

            discr_real = mnist_gan(real_data).view(-1)
            discr_fake = mnist_gan(fake_data).view(-1)

            discr_loss_real = criterion(discr_real, real_labels)
            discr_loss_fake = criterion(discr_fake, fake_labels)

            total_discr_loss = discr_loss_real + discr_loss_fake
            discr_epoch_loss += total_discr_loss.item()

            discr_optim.zero_grad()
            total_discr_loss.backward()
            discr_optim.step()

    # --- generator update ---
    for _ in range(len(train_loader)):
        # with as_fake=False the noise labels are ones, so the generator is
        # rewarded when the discriminator classifies its samples as real
        noise_data, noise_labels = generate_noise(batch_size, mnist_gan.latent_dimension, as_fake=False)
        noise_data = noise_data.to(device)
        noise_labels = noise_labels.to(device)

        # with train_generator=True the GAN forward pass returns the generated
        # samples together with the discriminator's scores for them
        generated_data, discr_gene = mnist_gan(noise_data, train_generator=True)

        total_gene_loss = criterion(discr_gene.view(-1), noise_labels)
        gene_epoch_loss += total_gene_loss.item()

        gene_optim.zero_grad()
        total_gene_loss.backward()
        gene_optim.step()

    if (epoch + 1) % 10 == 0:
        print("\nLoss in epoch", epoch + 1)
        print("Discriminator:", discr_epoch_loss / len(train_loader))
        print("Generator:", gene_epoch_loss / len(train_loader))