Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class VAE(nn.Module):
    """Minimal two-layer model: Linear(784 -> 400), Linear(400 -> 20), sigmoid.

    Despite the name, there is no encoder/decoder split or reparameterization
    step here; the module flattens its input to rows of 784 values and maps
    each row to 20 sigmoid-squashed outputs.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order (fc1, then fc21) so that seeded
        # parameter initialization is reproducible.
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, 784)`` and return ``sigmoid(fc21(fc1(x)))``."""
        flat = x.view(-1, 784)
        # NOTE: no nonlinearity between fc1 and fc21, so the two layers
        # compose into a single affine map before the sigmoid.
        return torch.sigmoid(self.fc21(self.fc1(flat)))
- # helper function to get sum of List[Tensor]
- def _sum_of_list(tensorlist):
- s = 0
- for t in tensorlist:
- if isinstance(t, torch.Tensor):
- s += t.sum()
- return s
def clone_inputs(arg):
    """Make a detached copy of *arg* that requires grad and return it twice.

    The copy has no autograd history from *arg*. The same tensor object is
    returned as both elements of the pair; in this script one element is fed
    to the traced model and the other is used as the differentiation target.
    """
    leaf = arg.detach().clone()
    leaf.requires_grad_()
    return leaf, leaf
# Trace the model on a random batch, then differentiate through the traced
# graph twice (the first gradient is built with create_graph=True so that it
# can itself be differentiated).
input_tensor = torch.rand(128, 1, 28, 28, requires_grad=True)
traced = torch.jit.trace(VAE(), input_tensor)

# A fresh grad-requiring copy of the input serves both as the model input and
# as the tensor the gradients are taken with respect to.
recording_inputs, recording_tensors = clone_inputs(input_tensor)
outputs = traced(recording_inputs)

# First-order gradient of the summed output.
l1 = _sum_of_list(outputs)
grads = torch.autograd.grad(
    outputs=l1,
    inputs=recording_tensors,
    create_graph=True,
    allow_unused=True,
)

# Second-order: differentiate (sum of first-order grads) * l1 again.
l2 = _sum_of_list(grads) * l1
grads2 = torch.autograd.grad(
    outputs=l2,
    inputs=recording_tensors,
    create_graph=True,
    allow_unused=True,
)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement