Advertisement
Guest User

Untitled

a guest
Feb 19th, 2020
149
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.14 KB | None | 0 0
  1. def update_moving(previous, current, alpha):
  2. if previous is None:
  3. return current
  4. return previous * alpha + current * (1 - alpha)
  5.  
  6. class BatchNormalization(Module):
  7. EPS = 1e-3
  8. def __init__(self, alpha = 0.):
  9. super(BatchNormalization, self).__init__()
  10. self.alpha = alpha
  11. self.moving_mean = None
  12. self.moving_variance = None
  13. self._stds = None
  14. self.sqrtVal = None
  15. self.xcentered = None
  16.  
  17. def updateOutput(self, input):
  18. if self.training:
  19. means = np.mean(input, axis=0, keepdims=False)
  20. self._stds = np.mean((input - means[np.newaxis, ...]) ** 2, axis=0)
  21. self.moving_mean = update_moving(self.moving_mean, means, self.alpha)
  22. self.moving_variance = update_moving(self.moving_variance, self._stds * (len(input) / (len(input) - 1)), self.alpha)
  23. else:
  24. means = self.moving_mean
  25. self._stds = self.moving_variance
  26. self.mus = means
  27. self.xcentered = (input - means[np.newaxis, ...])
  28. self.sqrtVal = np.sqrt(self._stds + self.EPS)
  29. self.output = self.xcentered / self.sqrtVal
  30.  
  31. return self.output
  32.  
  33. def updateGradInput(self, input, gradOutput):
  34. if self.training:
  35. batch_size = len(input)
  36. N = batch_size
  37. inp_size = len(input[0])
  38. sqr = self.sqrtVal[np.newaxis, :, np.newaxis]
  39. nominator = (np.identity(batch_size) - (np.ones((batch_size, batch_size)) / N))[:, np.newaxis, :]
  40. first_summand = nominator / sqr
  41. mus = self.mus[np.newaxis, :, np.newaxis]
  42. second_nominator = (input[:,:,np.newaxis] - mus) * (np.transpose(input)[np.newaxis,:,:] - mus) / N
  43. second_summand = second_nominator / sqr ** 3
  44. self.gradInput = np.transpose(np.sum(gradOutput[:, :, np.newaxis] * (first_summand + second_summand), axis=0))
  45. else:
  46. self.gradInput = gradOutput / np.sqrt(self._stds[np.newaxis, ...] + self.EPS)
  47. return self.gradInput
  48.  
  49. def __repr__(self):
  50. return "BatchNormalization"
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement