normalized weight decay

a guest
May 5th, 2018
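Written out, the rule that get_norm_wds applies gives each warm-restart cycle its own weight decay. The symbols are chosen here for illustration: lambda_norm = w_norm, b = bs, B = ds, T_0 = cycle_len, m = cycle_mult, N = num_cycles. This appears to have the same form as the normalized weight decay proposed by Loshchilov and Hutter in "Decoupled Weight Decay Regularization", with each cycle treated as a separate run of T_i epochs.

$$
\lambda_i = \lambda_{\text{norm}} \sqrt{\frac{b}{B\,T_i}}, \qquad T_i = T_0\, m^{i}, \qquad i = 0, \dots, N-1
$$

Because B T_i / b is the number of batch passes in cycle i, lambda_i equals lambda_norm exactly when the cycle is a single batch pass, and it shrinks as 1/sqrt(batch passes) for longer cycles.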
import math

import numpy as np


def get_norm_wds(w_norm, num_cycles, cycle_len, cycle_mult, bs, ds):
    """
    Returns a list of normalized weight decay factors, one per epoch.

    Arguments:
        w_norm: float, normalized weight decay (the weight decay that would be
            used if training took only a single batch pass)
        num_cycles: int, number of cycles of model fitting (a cycle can span
            multiple epochs)
        cycle_len: int, number of epochs in the first cycle
        cycle_mult: int, factor by which the cycle length is multiplied after
            each restart
        bs: int, batch size
        ds: int, number of training samples per epoch
    """
    epochs = []
    for i in range(num_cycles):
        # number of epochs in the current cycle (before the restart)
        epochs.append(cycle_len * cycle_mult ** i)

    cycle_wds = []
    for i in range(num_cycles):
        # weight decay for the current cycle: w_norm / sqrt(batch passes in the cycle)
        cycle_wds.append(w_norm * math.sqrt(bs / (ds * epochs[i])))

    # repeat each cycle's weight decay once for every epoch in that cycle
    wds = np.repeat(cycle_wds, epochs)
    return wds.tolist()
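Where NumPy is not available, the same per-epoch schedule can be built with the standard library alone. The following is a minimal sketch that should match get_norm_wds value for value; the name norm_wds_per_epoch and the example numbers (50,000 samples, batch size 64, cycles of 1, 2 and 4 epochs) are hypothetical and chosen only for illustration.

```python
import math


def norm_wds_per_epoch(w_norm, num_cycles, cycle_len, cycle_mult, bs, ds):
    """Hypothetical NumPy-free variant of get_norm_wds: one weight decay per epoch."""
    wds = []
    for i in range(num_cycles):
        n_epochs = cycle_len * cycle_mult ** i         # epochs in cycle i
        wd = w_norm * math.sqrt(bs / (ds * n_epochs))  # normalized decay for cycle i
        wds.extend([wd] * n_epochs)                    # same value for every epoch of the cycle
    return wds


# Example: 3 cycles of 1, 2 and 4 epochs (7 epochs in total).
wds = norm_wds_per_epoch(w_norm=0.025, num_cycles=3, cycle_len=1,
                         cycle_mult=2, bs=64, ds=50_000)
print(len(wds))  # 7
print([round(w, 6) for w in wds])
```

With these numbers the first cycle gets about 8.94e-4, the two-epoch cycle about 6.32e-4 (divided by sqrt(2)) and the four-epoch cycle about 4.47e-4 (halved). `list.extend` plays the role of `np.repeat` here; the NumPy version is only more convenient when the schedule is later used as an array.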