Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import numpy as np
- from dataclasses import dataclass
@dataclass(frozen=True)
class Adam:
    """Adam optimizer step with caller-owned moment state.

    The instance is immutable and only holds hyperparameters.
    The running moment estimates ``m`` and ``v`` belong to the caller.
    They are passed to :meth:`calc_update`, which overwrites them in place.

    Attributes:
        learning_rate: Multiplier applied to the final normalized update.
        eps: Small constant added to the bias-corrected second moment
            before the square root, to avoid division by zero.
        rho1: Exponential decay rate of the first-moment estimate ``m``.
        rho2: Exponential decay rate of the second-moment estimate ``v``.
    """

    learning_rate: float = 1.0
    eps: float = 1e-8
    rho1: float = 0.9
    rho2: float = 0.999

    def calc_update(
        self, grad: np.ndarray, m: np.ndarray, v: np.ndarray, t: int = 1
    ) -> np.ndarray:
        """Advance the moment estimates by one step and return the delta.

        Args:
            grad: Gradient for the current step.
            m: First-moment estimate; overwritten in place with the new
                exponential moving average of ``grad``.
            v: Second-moment estimate; overwritten in place with the new
                exponential moving average of ``grad ** 2``.
            t: Step counter used for bias correction. Values below 1 are
                treated as 1, so the correction denominators never reach 0.

        Returns:
            The parameter delta to *add* to the parameters. It is already
            negated and scaled by ``learning_rate``.
        """
        t = max(t, 1)
        # In-place writes so the caller's state arrays carry over to the
        # next call.
        m[:] = m * self.rho1 + grad * (1 - self.rho1)
        v[:] = v * self.rho2 + np.square(grad) * (1 - self.rho2)
        # Each moment is bias-corrected with its *own* decay rate: rho1 for
        # m and rho2 for v. Mixing them up skews the step size for many
        # early steps.
        m_hat = m / (1 - self.rho1 ** t)
        v_hat = v / (1 - self.rho2 ** t)
        # NOTE: eps is added inside the square root, unlike the
        # m_hat / (sqrt(v_hat) + eps) form of the original paper. It is
        # kept as-is so existing results do not shift.
        update = m_hat / np.sqrt(v_hat + self.eps)
        return -self.learning_rate * update
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement