Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# ---------------
# Train Generator
# ---------------
# One generator update: sample latent noise, generate fakes, and lower the
# discriminator's confidence that those fakes are fake.
optimizer_G.zero_grad()

# Create the noise on the same device as the real batch so the step also
# works when the data and models are not on the CPU.
z1 = torch.randn((args.batch_size, args.latent_dim), device=imgs.device)
fake_imgs1 = generator(z1)

# Minimax generator objective: minimise E[log(1 - D(G(z)))].
# NOTE(review): assumes the discriminator returns probabilities in (0, 1)
# (e.g. a final sigmoid) -- confirm against the model definition. This form
# is known to saturate when the discriminator rejects fakes confidently; the
# non-saturating alternative is -log(D(G(z))). Left unchanged here.
g_loss = torch.log(1 - discriminator(fake_imgs1)).mean()
g_loss.backward()
optimizer_G.step()

# -------------------
# Train Discriminator
# -------------------
# One discriminator update on a real batch plus a freshly generated fake
# batch.
optimizer_D.zero_grad()

z2 = torch.randn((args.batch_size, args.latent_dim), device=imgs.device)
# Detach the fakes: this step must only update the discriminator, so there is
# no reason to back-propagate through (and leave gradients on) the generator.
fake_imgs2 = generator(z2).detach()

# Real images are flattened to vectors of 28 * 28 values before scoring.
p_x = discriminator(imgs.view(-1, 28 * 28))
p_g = discriminator(fake_imgs2)

# The discriminator wants to MAXIMISE E[log D(x)] + E[log(1 - D(G(z)))].
# Optimizers perform descent, so the loss is the negative of that objective.
# The two expectations are averaged separately so a real batch whose size
# differs from args.batch_size (e.g. the last batch of an epoch) does not
# cause a shape mismatch or silent broadcasting.
d_loss = -(torch.log(p_x).mean() + torch.log(1 - p_g).mean())
d_loss.backward()
optimizer_D.step()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement