Untitled

a guest
Oct 19th, 2018
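```python
import torch

# Body of the training loop: `generator`, `discriminator`, `optimizer_G`,
# `optimizer_D`, `args`, and the real batch `imgs` are defined elsewhere.

# Train Generator
# ---------------
optimizer_G.zero_grad()

z1 = torch.randn((args.batch_size, args.latent_dim))
fake_imgs1 = generator(z1)

# Minimax generator loss: minimize E[log(1 - D(G(z)))].
g_loss = torch.log(1 - discriminator(fake_imgs1)).mean()
g_loss.backward()

optimizer_G.step()

# Train Discriminator
# -------------------
optimizer_D.zero_grad()

z2 = torch.randn((args.batch_size, args.latent_dim))
# detach(): the discriminator update should not backpropagate into the generator.
fake_imgs2 = generator(z2).detach()

p_x = discriminator(imgs.view(-1, 28 * 28))
p_g = discriminator(fake_imgs2)

# The discriminator must MAXIMIZE E[log D(x)] + E[log(1 - D(G(z)))].
# Optimizers minimize, so the loss is the negated objective; without the
# minus sign, D would be trained to do the opposite. Taking the two means
# separately also tolerates a final real batch smaller than args.batch_size.
d_loss = -(torch.log(p_x).mean() + torch.log(1 - p_g).mean())
d_loss.backward()

optimizer_D.step()
```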
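A minimal, self-contained sketch of the corrected update on random data can help check the fix in isolation. The tiny MLP architectures, the sizes, the Adam settings, the `EPS` guard, and the `train_step` helper are all hypothetical stand-ins for whatever the full script defines; only the two loss expressions follow the paste.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration (the paste reads them from `args`).
latent_dim, img_dim, batch_size = 16, 28 * 28, 8

torch.manual_seed(0)

# Hypothetical tiny MLPs. The discriminator ends in a sigmoid so it outputs a
# probability, which is what the torch.log(...) losses assume.
generator = nn.Sequential(nn.Linear(latent_dim, 64), nn.ReLU(),
                          nn.Linear(64, img_dim), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(img_dim, 64), nn.LeakyReLU(0.2),
                              nn.Linear(64, 1), nn.Sigmoid())
optimizer_G = torch.optim.Adam(generator.parameters(), lr=2e-4)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

EPS = 1e-8  # hypothetical guard that keeps log() finite if D hits exactly 0 or 1


def train_step(imgs):
    """One generator update followed by one discriminator update."""
    # Generator: minimize E[log(1 - D(G(z)))] (the minimax form from the paste).
    optimizer_G.zero_grad()
    z = torch.randn(imgs.size(0), latent_dim)
    g_loss = torch.log(1 - discriminator(generator(z)) + EPS).mean()
    g_loss.backward()
    optimizer_G.step()

    # Discriminator: minimize -(E[log D(x)] + E[log(1 - D(G(z)))]).
    # detach() keeps the generator's graph out of this step.
    optimizer_D.zero_grad()  # also clears D grads left over from g_loss.backward()
    z = torch.randn(imgs.size(0), latent_dim)
    fake = generator(z).detach()
    p_x = discriminator(imgs.view(-1, img_dim))
    p_g = discriminator(fake)
    d_loss = -(torch.log(p_x + EPS).mean() + torch.log(1 - p_g + EPS).mean())
    d_loss.backward()
    optimizer_D.step()
    return g_loss.item(), d_loss.item()


# Random stand-in for a batch of MNIST-shaped images scaled to [-1, 1].
imgs = torch.rand(batch_size, 1, 28, 28) * 2 - 1
g, d = train_step(imgs)
print(f"g_loss={g:.4f}  d_loss={d:.4f}")
```

A quick sanity check follows from the math: with the minus sign in place, `d_loss` is a negated sum of log-probabilities and so can never go below zero, and a discriminator that outputs 0.5 everywhere gives 2 ln 2 ≈ 1.386. A `d_loss` that is consistently negative means the sign is wrong.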
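The generator loss in the paste, minimizing log(1 - D(G(z))), is the minimax form, and it is known to saturate: when the discriminator confidently rejects fakes, its gradient vanishes. The original GAN paper proposed training G with the non-saturating variant -log D(G(z)) instead. The sketch below compares the two gradients with respect to the discriminator's pre-sigmoid logit; the three logit values and the `grad_wrt_logits` helper are hypothetical and chosen only for illustration.

```python
import torch

# Hypothetical discriminator logits for three fake samples. A very negative
# logit means D(G(z)) is close to 0: the discriminator confidently rejects the
# fake, which is typical early in training.
logits = torch.tensor([-6.0, -3.0, 0.0], requires_grad=True)


def grad_wrt_logits(loss_fn):
    """Gradient of sum(loss_fn(D(G(z)))) with respect to the logits."""
    d_fake = torch.sigmoid(logits)  # build a fresh graph on every call
    (grad,) = torch.autograd.grad(loss_fn(d_fake).sum(), logits)
    return grad


# Minimax generator loss from the paste (minimized): log(1 - D(G(z))).
grad_minimax = grad_wrt_logits(lambda d: torch.log(1 - d))
# Non-saturating alternative (minimized): -log D(G(z)).
grad_nonsat = grad_wrt_logits(lambda d: -torch.log(d))

print("D(G(z))     :", torch.sigmoid(logits).detach())
print("minimax grad:", grad_minimax)  # analytically -sigmoid(logit)
print("non-sat grad:", grad_nonsat)   # analytically -(1 - sigmoid(logit))
```

Both losses push D(G(z)) upward, but for the logit of -6 the minimax gradient is about -0.0025 while the non-saturating one is about -0.9975, so the generator keeps receiving a useful signal exactly when it needs it most.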