jack06215

[keras] cyclic learning rate

May 28th, 2020
69
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.88 KB | None | 0 0
  1. class TriangularSchedule():
  2.     def __init__(self, min_lr, max_lr, cycle_length, inc_fraction=0.5):    
  3.         self.min_lr = min_lr
  4.         self.max_lr = max_lr
  5.         self.cycle_length = cycle_length
  6.         self.inc_fraction = inc_fraction
  7.        
  8.     def __call__(self, iteration):
  9.         if iteration <= self.cycle_length*self.inc_fraction:
  10.             unit_cycle = iteration * 1 / (self.cycle_length * self.inc_fraction)
  11.         elif iteration <= self.cycle_length:
  12.             unit_cycle = (self.cycle_length - iteration) * 1 / (self.cycle_length * (1 - self.inc_fraction))
  13.         else:
  14.             unit_cycle = 0
  15.         adjusted_cycle = (unit_cycle * (self.max_lr - self.min_lr)) + self.min_lr
  16.         return adjusted_cycle
  17.  
  18. class CyclicalSchedule():
  19.     def __init__(self, schedule_class, cycle_length, cycle_length_decay=1, cycle_magnitude_decay=1, **kwargs):
  20.         self.schedule_class = schedule_class
  21.         self.length = cycle_length
  22.         self.length_decay = cycle_length_decay
  23.         self.magnitude_decay = cycle_magnitude_decay
  24.         self.kwargs = kwargs
  25.    
  26.     def __call__(self, iteration):
  27.         cycle_idx = 0
  28.         cycle_length = self.length
  29.         idx = self.length
  30.         while idx <= iteration:
  31.             cycle_length = math.ceil(cycle_length * self.length_decay)
  32.             cycle_idx += 1
  33.             idx += cycle_length
  34.         cycle_offset = iteration - idx + cycle_length
  35.        
  36.         schedule = self.schedule_class(cycle_length=cycle_length, **self.kwargs)
  37.         return schedule(cycle_offset) * self.magnitude_decay**cycle_idx
  38.  
  39.  
  40. schedule = CyclicalSchedule(TriangularSchedule, min_lr=0.5, max_lr=2, cycle_length=500)
  41. iterations=2000
  42.  
  43. plt.plot([i+1 for i in range(iterations)],[schedule(i) for i in range(iterations)])
  44. plt.title('Learning rate for each epoch')
  45. plt.xlabel("Epoch")
  46. plt.ylabel("Learning Rate")
  47. plt.show()
Add Comment
Please, Sign In to add comment