Advertisement
Guest User

Untitled

a guest
May 21st, 2019
72
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.02 KB | None | 0 0
  1. class VAE(nn.Module):
  2. def __init__(self):
  3. super(VAE, self).__init__()
  4.  
  5. self.fc1 = nn.Linear(784, 400)
  6. self.fc21 = nn.Linear(400, 20)
  7.  
  8. def forward(self, x):
  9. x = x.view(-1, 784)
  10. h1 = self.fc1(x)
  11. h2 = self.fc21(h1)
  12. return torch.sigmoid(h2)
  13.  
  14.  
  15. # helper function to get sum of List[Tensor]
  16. def _sum_of_list(tensorlist):
  17. s = 0
  18. for t in tensorlist:
  19. if isinstance(t, torch.Tensor):
  20. s += t.sum()
  21. return s
  22.  
  23. def clone_inputs(arg):
  24. input = arg.detach().clone().requires_grad_()
  25. return input, input
  26.  
  27.  
  28. input_tensor = torch.rand((128, 1, 28, 28), requires_grad=True)
  29. traced = torch.jit.trace(VAE(), input_tensor)
  30.  
  31.  
  32. recording_inputs, recording_tensors = clone_inputs(input_tensor)
  33. outputs = traced(recording_inputs)
  34. l1 = _sum_of_list(outputs)
  35. grads = torch.autograd.grad(l1, recording_tensors, create_graph=True, allow_unused=True)
  36.  
  37. l2 = (_sum_of_list(grads) * l1)
  38. grads2 = torch.autograd.grad(l2, recording_tensors, create_graph=True, allow_unused=True)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement