Advertisement
Guest User

Untitled

a guest
Dec 9th, 2019
115
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.62 KB | None | 0 0
  1. def forward(self, x):
  2. if self.training:
  3. mu_B = torch.mean(x, dim = 0)
  4. sigma_2_B = torch.mean(x - mu_B, dim = 0)
  5. x_hat = (x - mu_B) / torch.sqrt(sigma_2_B + self.momentum)
  6. y_hat = self.gamma * x_hat + self.beta
  7.  
  8. self.mu = (1 - self.momentum) * self.mu + self.momentum * mu_B
  9. self.sigma = (1 - self.momentum) * self.sigma + self.momentum * torch.sqrt(sigma_2_B)
  10. else:
  11. x_hat = (x - self.mu) / torch.sqrt(self.sigma ** 2 + self.momentum)
  12. y_hat = self.gamma * x_hat + self.beta
  13. return y_hat;
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement