Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def forward(self, x):
    """Apply batch normalization to ``x``.

    In training mode the input is normalized with the mean and (biased)
    variance of the current batch, both reduced over dim 0, and the
    running estimates ``self.mu`` / ``self.sigma`` are updated as
    exponential moving averages weighted by ``self.momentum``.
    In evaluation mode the stored running estimates are used instead, so
    the output does not depend on the composition of the batch.

    ``self.sigma`` holds the running *variance*: the evaluation branch
    takes ``sqrt(self.sigma + self.eps)``, so the training branch must
    accumulate the variance, not the standard deviation.

    Args:
        x: Input tensor; statistics are computed along dim 0.

    Returns:
        ``self.gamma * x_hat + self.beta`` where ``x_hat`` is the
        normalized input.
    """
    if self.training:
        # Statistics of the current batch (not the running averages).
        mu_B = torch.mean(x, dim=0)
        sigma_B = torch.mean((x - mu_B) ** 2, dim=0)
        x_hat = (x - mu_B) / torch.sqrt(sigma_B + self.eps)

        # Update the running estimates. The batch statistics are detached
        # so the stored averages do not keep every batch's autograd graph
        # alive, and the variance itself is accumulated to match how the
        # evaluation branch consumes ``self.sigma``.
        self.mu = (
            (1 - self.momentum) * self.mu + self.momentum * mu_B.detach()
        )
        self.sigma = (
            (1 - self.momentum) * self.sigma
            + self.momentum * sigma_B.detach()
        )
    else:
        # Inference: normalize with the statistics gathered during training.
        x_hat = (x - self.mu) / torch.sqrt(self.sigma + self.eps)

    return self.gamma * x_hat + self.beta
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement