import math
from typing import List, Optional

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler


class CosineAnnealingWarmupRestarts(LRScheduler):
    r"""Cosine annealing with linear warmup and warm restarts.

    :param optimizer: Optimizer. wrapped optimizer instance.
    :param first_cycle_steps: int. number of steps in the first cycle.
    :param cycle_mult: float. cycle length multiplier applied at every restart.
    :param max_lr: float. maximum learning rate of the first cycle.
    :param min_lr: float. minimum learning rate.
    :param warmup_steps: int. number of linear warmup steps.
    :param gamma: float. decay factor applied to max_lr at every restart.
    :param last_epoch: int. index of the last epoch.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        first_cycle_steps: int,
        cycle_mult: float = 1.0,
        max_lr: float = 1e-4,
        min_lr: float = 1e-6,
        warmup_steps: int = 0,
        gamma: float = 0.9,
        last_epoch: int = -1,
    ):
        if warmup_steps >= first_cycle_steps:
            raise ValueError(
                f'[-] warmup_steps must be smaller than first_cycle_steps '
                f'(got warmup_steps={warmup_steps}, first_cycle_steps={first_cycle_steps})'
            )

        self.first_cycle_steps = first_cycle_steps
        self.cycle_mult = cycle_mult
        self.base_max_lr = max_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.gamma = gamma
        self.cur_cycle_steps = first_cycle_steps
        self.step_in_cycle = last_epoch
        self.last_epoch = last_epoch

        self.cycle: int = 0
        self.base_lrs: List[float] = []

        super().__init__(optimizer, last_epoch)

        self.init_lr()

    def init_lr(self) -> None:
        # Every param group starts the warmup from min_lr.
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)

    def get_lr(self) -> List[float]:
        if self.step_in_cycle == -1:
            return self.base_lrs

        if self.step_in_cycle < self.warmup_steps:
            # Linear warmup from base_lr (min_lr) up to max_lr.
            return [
                (self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps + base_lr
                for base_lr in self.base_lrs
            ]

        # Cosine annealing from max_lr back down to base_lr within the cycle.
        return [
            base_lr
            + (self.max_lr - base_lr)
            * (
                1
                + math.cos(
                    math.pi * (self.step_in_cycle - self.warmup_steps) / (self.cur_cycle_steps - self.warmup_steps)
                )
            )
            / 2.0
            for base_lr in self.base_lrs
        ]

    def step(self, epoch: Optional[int] = None):
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # Restart: advance to the next cycle and grow its length by cycle_mult.
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = (
                    int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
                )
        elif epoch >= self.first_cycle_steps:
            if self.cycle_mult == 1.0:
                self.step_in_cycle = epoch % self.first_cycle_steps
                self.cycle = epoch // self.first_cycle_steps
            else:
                # Recover the cycle index from the geometric series of cycle lengths.
                n: int = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                self.cycle = n
                self.step_in_cycle = epoch - int(
                    self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)
                )
                self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** n
        else:
            self.cur_cycle_steps = self.first_cycle_steps
            self.step_in_cycle = epoch

        # max_lr decays by gamma at every restart.
        self.max_lr = self.base_max_lr * (self.gamma ** self.cycle)
        self.last_epoch = math.floor(epoch)

        lrs = self.get_lr()

        for param_group, lr in zip(self.optimizer.param_groups, lrs):
            param_group['lr'] = lr

        self._last_lr = lrs

    def get_last_lr(self) -> List[float]:
        # If _last_lr has not been set yet, return freshly computed lrs.
        if hasattr(self, '_last_lr'):
            return self._last_lr
        return self.get_lr()

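# --- Usage sketch (illustrative; not part of the original paste) ----------------
# A minimal example of driving CosineAnnealingWarmupRestarts; the model, the
# AdamW optimizer, and every hyper-parameter below are assumptions made up for
# the sketch.
#
#   import torch
#
#   model = torch.nn.Linear(8, 1)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   scheduler = CosineAnnealingWarmupRestarts(
#       optimizer,
#       first_cycle_steps=1000,
#       cycle_mult=1.0,
#       max_lr=1e-4,
#       min_lr=1e-6,
#       warmup_steps=100,
#       gamma=0.9,
#   )
#
#   for batch in range(3000):
#       ...                       # forward / backward
#       optimizer.step()
#       scheduler.step()          # call once per optimization step
#       current_lr = scheduler.get_last_lr()[0]
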
class CosineAnnealingWarmupRestartsPD(LRScheduler):
    r"""CosineAnnealingWarmupRestarts with per-param-group LR scaling.

    Each parameter group starts from its own initial lr, and the scheduler
    lowers/raises it relative to those values.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        first_cycle_steps: int,
        cycle_mult: float = 1.0,
        max_lr: float = 1e-4,  # global maximum for the fastest group
        min_lr: float = 1e-6,  # global minimum for the fastest group
        warmup_steps: int = 0,
        gamma: float = 0.9,
        last_epoch: int = -1,
    ):
        if warmup_steps >= first_cycle_steps:
            raise ValueError(
                f"warmup_steps must be smaller than first_cycle_steps "
                f"(got warmup_steps={warmup_steps}, first_cycle_steps={first_cycle_steps})"
            )

        self.first_cycle_steps = first_cycle_steps
        self.cycle_mult = cycle_mult
        self.global_base_max_lr = max_lr
        self.global_min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.gamma = gamma
        self.cur_cycle_steps = first_cycle_steps
        self.step_in_cycle = last_epoch
        self.last_epoch = last_epoch
        self.cycle = 0

        # Remember the initial lr of every param group.
        self.init_base_lrs = [pg["lr"] for pg in optimizer.param_groups]

        # Scaling coefficient per group: the group with the largest initial lr
        # follows max_lr exactly, the others are scaled down proportionally.
        max_init_lr = max(self.init_base_lrs)
        self.scale_factors = [init_lr / max_init_lr for init_lr in self.init_base_lrs]

        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Compute the lr of every group, preserving its proportion to the initial values."""
        if self.step_in_cycle == -1:
            return self.init_base_lrs

        # Current global maximum (for the fastest group); decays by gamma per cycle.
        current_global_max = self.global_base_max_lr * (self.gamma ** self.cycle)
        current_global_min = self.global_min_lr

        lrs = []
        for scale in self.scale_factors:
            # Per-group min/max.
            group_max = current_global_max * scale
            group_min = current_global_min * scale

            if self.step_in_cycle < self.warmup_steps:
                # Linear warmup.
                lr = group_min + (group_max - group_min) * self.step_in_cycle / self.warmup_steps
            else:
                # Cosine annealing.
                lr = group_min + (group_max - group_min) * (
                    1 + math.cos(
                        math.pi * (self.step_in_cycle - self.warmup_steps)
                        / (self.cur_cycle_steps - self.warmup_steps)
                    )
                ) / 2.0

            lrs.append(lr)

        return lrs

    def step(self, epoch: Optional[int] = None):
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle += 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # Restart: advance the cycle and grow its length by cycle_mult.
                self.cycle += 1
                self.step_in_cycle -= self.cur_cycle_steps
                self.cur_cycle_steps = (
                    int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult)
                    + self.warmup_steps
                )
        elif epoch >= self.first_cycle_steps:
            if self.cycle_mult == 1.0:
                self.step_in_cycle = epoch % self.first_cycle_steps
                self.cycle = epoch // self.first_cycle_steps
            else:
                # Recover the cycle index from the geometric series of cycle lengths.
                n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                self.cycle = n
                self.step_in_cycle = epoch - int(
                    self.first_cycle_steps * (self.cycle_mult**n - 1) / (self.cycle_mult - 1)
                )
                self.cur_cycle_steps = self.first_cycle_steps * (self.cycle_mult**n)
        else:
            self.cur_cycle_steps = self.first_cycle_steps
            self.step_in_cycle = epoch

        self.last_epoch = math.floor(epoch)

        lrs = self.get_lr()
        for pg, lr in zip(self.optimizer.param_groups, lrs):
            pg["lr"] = lr

        self._last_lr = lrs

    def get_last_lr(self) -> List[float]:
        return getattr(self, "_last_lr", self.get_lr())

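# --- Usage sketch for the per-group variant (illustrative; not in the original) -
# CosineAnnealingWarmupRestartsPD reads each group's initial lr from the
# optimizer: the group with the largest initial lr follows max_lr exactly, the
# others are scaled down proportionally.  A hypothetical two-group setup:
#
#   import torch
#
#   backbone = torch.nn.Linear(8, 8)
#   head = torch.nn.Linear(8, 1)
#   optimizer = torch.optim.AdamW([
#       {'params': backbone.parameters(), 'lr': 1e-5},  # scale factor 0.1
#       {'params': head.parameters(), 'lr': 1e-4},      # scale factor 1.0
#   ])
#   scheduler = CosineAnnealingWarmupRestartsPD(
#       optimizer, first_cycle_steps=1000, max_lr=1e-4, min_lr=1e-6, warmup_steps=100,
#   )
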
class CosineAnnealingWarmupRestartsDTEFIX(LRScheduler):
    r"""CosineAnnealingWarmupRestarts with per-param-group LR scaling.

    Each parameter group starts from its own initial lr, and the scheduler
    lowers/raises it relative to those values.  Unlike the PD variant above,
    the lower bound is the unscaled global min_lr for every group, and the
    result is clamped so it never falls below min_lr.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        first_cycle_steps: int,
        cycle_mult: float = 1.0,
        max_lr: float = 1e-4,  # global maximum for the fastest group
        min_lr: float = 1e-6,  # global minimum, shared by all groups
        warmup_steps: int = 0,
        gamma: float = 0.9,
        last_epoch: int = -1,
    ):
        if warmup_steps >= first_cycle_steps:
            raise ValueError(
                f"warmup_steps must be smaller than first_cycle_steps "
                f"(got warmup_steps={warmup_steps}, first_cycle_steps={first_cycle_steps})"
            )

        self.first_cycle_steps = first_cycle_steps
        self.cycle_mult = cycle_mult
        self.global_base_max_lr = max_lr
        self.global_min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.gamma = gamma
        self.cur_cycle_steps = first_cycle_steps
        self.step_in_cycle = last_epoch
        self.last_epoch = last_epoch
        self.cycle = 0

        # Remember the initial lr of every param group.
        self.init_base_lrs = [pg["lr"] for pg in optimizer.param_groups]

        # Scaling coefficient per group: the group with the largest initial lr
        # follows max_lr exactly, the others are scaled down proportionally.
        max_init_lr = max(self.init_base_lrs)
        self.scale_factors = [init_lr / max_init_lr for init_lr in self.init_base_lrs]

        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        if self.step_in_cycle == -1:
            return self.init_base_lrs

        # The maximum decays by gamma across cycles; the minimum stays fixed.
        current_global_max = self.global_base_max_lr * (self.gamma ** self.cycle)

        lrs = []
        for scale in self.scale_factors:
            group_max = current_global_max * scale
            group_min = self.global_min_lr  # not scaled down per group

            if self.step_in_cycle < self.warmup_steps:
                # Linear warmup.
                lr = group_min + (group_max - group_min) * self.step_in_cycle / self.warmup_steps
            else:
                # Cosine annealing.
                lr = group_min + (group_max - group_min) * (
                    1 + math.cos(
                        math.pi * (self.step_in_cycle - self.warmup_steps)
                        / (self.cur_cycle_steps - self.warmup_steps)
                    )
                ) / 2.0

            # Guarantee the lr never drops below the global minimum.
            lr = max(lr, self.global_min_lr)
            lrs.append(lr)

        return lrs

    def step(self, epoch: Optional[int] = None):
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle += 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # Restart: advance the cycle and grow its length by cycle_mult.
                self.cycle += 1
                self.step_in_cycle -= self.cur_cycle_steps
                self.cur_cycle_steps = (
                    int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult)
                    + self.warmup_steps
                )
        elif epoch >= self.first_cycle_steps:
            if self.cycle_mult == 1.0:
                self.step_in_cycle = epoch % self.first_cycle_steps
                self.cycle = epoch // self.first_cycle_steps
            else:
                # Recover the cycle index from the geometric series of cycle lengths.
                n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                self.cycle = n
                self.step_in_cycle = epoch - int(
                    self.first_cycle_steps * (self.cycle_mult**n - 1) / (self.cycle_mult - 1)
                )
                self.cur_cycle_steps = self.first_cycle_steps * (self.cycle_mult**n)
        else:
            self.cur_cycle_steps = self.first_cycle_steps
            self.step_in_cycle = epoch

        self.last_epoch = math.floor(epoch)

        lrs = self.get_lr()
        for pg, lr in zip(self.optimizer.param_groups, lrs):
            pg["lr"] = lr

        self._last_lr = lrs

    def get_last_lr(self) -> List[float]:
        return getattr(self, "_last_lr", self.get_lr())

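# --- Note on the variants (added commentary, not in the original paste) ---------
# The three per-group classes differ only in how each group's lower bound is chosen:
#   * CosineAnnealingWarmupRestartsPD:      group_min = min_lr * scale, no clamp.
#   * CosineAnnealingWarmupRestartsDTEFIX:  group_min = min_lr (unscaled), and the
#     result is clamped with max(lr, min_lr), so no group ever goes below min_lr.
#   * CosineAnnealingWarmupRestartsMINLRFIX (below): group_min = min_lr * scale,
#     clamped with max(lr, group_min), so groups with smaller initial lrs may
#     anneal below the global min_lr in proportion to their scale.
# Illustration with min_lr=1e-6 and assumed scale factors [1.0, 0.1]:
#   DTEFIX   floors -> [1e-6, 1e-6]
#   MINLRFIX floors -> [1e-6, 1e-7]
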
class CosineAnnealingWarmupRestartsMINLRFIX(LRScheduler):
    r"""CosineAnnealingWarmupRestarts with per-param-group LR scaling.

    min_lr is fixed and does not decrease across cycles; each group's lr is
    additionally clamped so it never falls below its own scaled minimum.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        first_cycle_steps: int,
        cycle_mult: float = 1.0,
        max_lr: float = 1e-4,  # global maximum for the fastest group
        min_lr: float = 1e-6,  # fixed global minimum
        warmup_steps: int = 0,
        gamma: float = 0.9,
        last_epoch: int = -1,
    ):
        if warmup_steps >= first_cycle_steps:
            raise ValueError(
                f"warmup_steps must be smaller than first_cycle_steps "
                f"(got warmup_steps={warmup_steps}, first_cycle_steps={first_cycle_steps})"
            )

        self.first_cycle_steps = first_cycle_steps
        self.cycle_mult = cycle_mult
        self.global_base_max_lr = max_lr
        self.global_min_lr = min_lr  # kept fixed
        self.warmup_steps = warmup_steps
        self.gamma = gamma
        self.cur_cycle_steps = first_cycle_steps
        self.step_in_cycle = last_epoch
        self.last_epoch = last_epoch
        self.cycle = 0

        # Initial lrs of every param group.
        self.init_base_lrs = [pg["lr"] for pg in optimizer.param_groups]

        # Per-group scale factors.
        max_init_lr = max(self.init_base_lrs)
        self.scale_factors = [init_lr / max_init_lr for init_lr in self.init_base_lrs]

        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        if self.step_in_cycle == -1:
            return self.init_base_lrs

        current_global_max = self.global_base_max_lr * (self.gamma ** self.cycle)
        current_global_min = self.global_min_lr  # fixed

        lrs = []
        for scale in self.scale_factors:
            group_max = current_global_max * scale
            group_min = current_global_min * scale

            if self.step_in_cycle < self.warmup_steps:
                # Linear warmup.
                lr = group_min + (group_max - group_min) * self.step_in_cycle / self.warmup_steps
            else:
                # Cosine annealing.
                lr = group_min + (group_max - group_min) * (
                    1 + math.cos(
                        math.pi * (self.step_in_cycle - self.warmup_steps)
                        / (self.cur_cycle_steps - self.warmup_steps)
                    )
                ) / 2.0

            # Guarantee the lr never drops below group_min.
            lrs.append(max(lr, group_min))

        return lrs

    def step(self, epoch: Optional[int] = None):
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle += 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # Restart: advance the cycle and grow its length by cycle_mult.
                self.cycle += 1
                self.step_in_cycle -= self.cur_cycle_steps
                self.cur_cycle_steps = (
                    int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult)
                    + self.warmup_steps
                )
        elif epoch >= self.first_cycle_steps:
            if self.cycle_mult == 1.0:
                self.step_in_cycle = epoch % self.first_cycle_steps
                self.cycle = epoch // self.first_cycle_steps
            else:
                # Recover the cycle index from the geometric series of cycle lengths.
                n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                self.cycle = n
                self.step_in_cycle = epoch - int(
                    self.first_cycle_steps * (self.cycle_mult**n - 1) / (self.cycle_mult - 1)
                )
                self.cur_cycle_steps = self.first_cycle_steps * (self.cycle_mult**n)
        else:
            self.cur_cycle_steps = self.first_cycle_steps
            self.step_in_cycle = epoch

        self.last_epoch = math.floor(epoch)

        lrs = self.get_lr()
        for pg, lr in zip(self.optimizer.param_groups, lrs):
            pg["lr"] = lr

        self._last_lr = lrs

    def get_last_lr(self) -> List[float]:
        return getattr(self, "_last_lr", self.get_lr())

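
if __name__ == '__main__':
    # Minimal smoke test (added for illustration; not part of the original paste).
    # Assumes torch is installed; the tiny model and hyper-parameters are arbitrary.
    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    scheduler = CosineAnnealingWarmupRestarts(
        optimizer,
        first_cycle_steps=10,
        cycle_mult=1.0,
        max_lr=1e-4,
        min_lr=1e-6,
        warmup_steps=3,
        gamma=0.5,
    )

    # The lr should warm up for 3 steps, anneal over the rest of the 10-step
    # cycle, then restart with max_lr halved (gamma=0.5).
    for step in range(25):
        optimizer.step()  # normally preceded by forward/backward
        scheduler.step()
        print(f'step={step:02d} lr={scheduler.get_last_lr()[0]:.2e}')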