Advertisement
Guest User

Untitled

a guest
Jun 25th, 2019
117
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.60 KB | None | 0 0
  1. import numpy as np
  2. from dataclasses import dataclass
  3.  
  4.  
  5. @dataclass(frozen=True)
  6. class Adam:
  7. learning_rate: float = 1.0
  8. eps: float = 1e-8
  9. rho1: float = 0.9
  10. rho2: float = 0.999
  11.  
  12. def calc_update(
  13. self, grad: np.ndarray, m: np.ndarray, v: np.ndarray, t: int = 1
  14. ) -> np.ndarray:
  15. t = max(t, 1)
  16. m[:] = m * self.rho1 + grad * (1 - self.rho1)
  17. v[:] = v * self.rho2 + np.square(grad) * (1 - self.rho2)
  18. m_ = m / (1 - self.rho1 ** t)
  19. v_ = v / (1 - self.rho1 ** t)
  20. update = m_ / np.sqrt(v_ + self.eps)
  21. return -self.learning_rate * update
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement