Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply batch normalization to ``x`` followed by a learned affine map.

    In training mode the input is normalized with the statistics of the
    current batch, and the running estimates ``self.mu`` (mean) and
    ``self.sigma`` (standard deviation) are updated as exponential moving
    averages weighted by ``self.momentum``.  In evaluation mode the stored
    running estimates are used instead.

    Statistics are reduced over dim 0 only.
    NOTE(review): this assumes ``x`` is laid out as (batch, features) --
    confirm against the callers before feeding higher-rank input.

    Args:
        x: Input tensor; dim 0 is treated as the batch dimension.

    Returns:
        ``self.gamma * x_hat + self.beta`` where ``x_hat`` is the
        normalized input.
    """
    # Numerical-stability term added to the variance before the square
    # root.  The momentum must not be used for this: it is orders of
    # magnitude too large and would distort the normalization.  Use the
    # module's ``eps`` when it defines one, else a conventional default.
    eps = getattr(self, "eps", 1e-5)

    if self.training:
        mu_B = torch.mean(x, dim=0)
        # Biased batch variance: the mean of the *squared* deviations.
        # (The mean of the raw deviations is identically ~0.)
        sigma_2_B = torch.mean((x - mu_B) ** 2, dim=0)
        x_hat = (x - mu_B) / torch.sqrt(sigma_2_B + eps)

        # Detach the batch statistics so the running estimates do not
        # drag the autograd graph of every past batch along with them.
        self.mu = (
            (1 - self.momentum) * self.mu
            + self.momentum * mu_B.detach()
        )
        self.sigma = (
            (1 - self.momentum) * self.sigma
            + self.momentum * torch.sqrt(sigma_2_B.detach())
        )
    else:
        # self.sigma stores a running standard deviation, so square it
        # to get back to a variance before adding eps.
        x_hat = (x - self.mu) / torch.sqrt(self.sigma ** 2 + eps)

    return self.gamma * x_hat + self.beta
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement