Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
def get_norm_wds(w_norm, num_cycles, cycle_len, cycle_mult, bs, ds):
    """
    Return a list of normalized weight-decay factors, one entry per epoch.

    Cycle ``i`` (0-based) lasts ``cycle_len * cycle_mult ** i`` epochs and is
    assigned the factor ``w_norm * sqrt(bs / (ds * epochs_in_cycle))``. That
    factor is repeated once for every epoch of the cycle. The per-cycle runs
    are concatenated in cycle order.

    Arguments:
        w_norm: float, weight decay if only one batch pass is allowed
        num_cycles: int, number of cycles of model fitting (can include
            multiple epochs per cycle)
        cycle_len: int, number of epochs in the first cycle
        cycle_mult: int, factor by which each cycle's epoch count is
            multiplied relative to the previous cycle
        bs: batch size
        ds: total size of training data per epoch

    Returns:
        list of float with one entry per epoch over all cycles. It is empty
        when ``num_cycles`` is 0. Cycles with zero epochs contribute nothing.

    Raises:
        ZeroDivisionError: if ``ds`` is 0 and some cycle has epochs.
    """
    wds = []
    for i in range(num_cycles):
        # Number of epochs in the current cycle (before the restart).
        n_epochs = cycle_len * cycle_mult ** i
        if n_epochs == 0:
            # An empty cycle adds no entries; skipping it also avoids
            # dividing by zero in the formula below.
            continue
        # Scale w_norm by sqrt(batch size / samples seen in this cycle).
        wd = w_norm * math.sqrt(bs / (ds * n_epochs))
        wds.extend([wd] * n_epochs)
    return wds
Add Comment
Please sign in to add a comment.