Guest User

Untitled

a guest
Oct 22nd, 2017
73
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.76 KB | None | 0 0
  1. def list_rate_schedule(lrates, output=True):
  2. sched = []
  3. last_lr = [0]
  4.  
  5. for lr, n in lrates:
  6. sched += [lr] * n
  7.  
  8. def lr_sched(epoch):
  9. lr = sched[-1]
  10. if epoch < len(sched):
  11. lr = sched[epoch]
  12. if output and lr != last_lr[0]:
  13. print('Learning rate: {}'.format(lr))
  14. last_lr[0] = lr
  15. return lr
  16.  
  17. return lr_sched
  18.  
  19. # usage ----------------------------------------------------
  20.  
  21. rates = [
  22. (1e-5, 2), # 2 epochs @ 1e-5
  23. (1e-3, 4), # 4 epochs @ 1e-3
  24. (1e-4, 8), # 8 epochs @ 1e-4
  25. # ...
  26. ]
  27.  
  28. lrsched = keras.callbacks.LearningRateScheduler(
  29. list_rate_schedule(rates))
  30.  
  31. model.fit_generator(
  32. trn_batches,
  33. trn_batches.samples // trn_batches.batch_size,
  34. callacks=[lrsched],
  35. # ...
  36. )
Add Comment
Please, Sign In to add comment