Guest User

Untitled

a guest
Jul 14th, 2020
430
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.06 KB | None | 0 0
  1. from learning_functions.create_partition import create_partition_and_labels
  2. from learning_functions.data_generator import DataGenerator
  3. from keras.callbacks import ModelCheckpoint, TensorBoard
  4. import datetime
  5. import os
  6. import inspect
  7. import gc
  8.  
  9.  
  10. def perform_learning(training_sample_dir, val_sample_dir,
  11. batch_size, three_d, sample_channels, categorise, output_classes, shuffle,
  12. model_func, model_params, epochs, model_path, checkpoint_path,
  13. log_name):
  14.  
  15. # create partition
  16. partition, labels = create_partition_and_labels(training_sample_dir, val_sample_dir)
  17.  
  18. # generators
  19. params = {'batch_size': batch_size,
  20. 'three_d': three_d,
  21. 'n_channels': sample_channels,
  22. 'categorise': categorise,
  23. 'n_classes': output_classes,
  24. 'shuffle': shuffle}
  25.  
  26. training_generator = DataGenerator(partition['train'], labels, training_sample_dir, **params)
  27. validation_generator = DataGenerator(partition['validation'], labels, val_sample_dir, **params)
  28.  
  29. # create checkpoint path
  30. cropped_path = checkpoint_path[:checkpoint_path.rfind('/')]
  31. if not os.path.exists(cropped_path):
  32. os.makedirs(cropped_path)
  33.  
  34. # set checkpoint
  35. checkpoint = ModelCheckpoint(checkpoint_path, period=3)
  36.  
  37. # create model
  38. model = model_func(**model_params)
  39.  
  40. # tensorboard
  41. now = datetime.datetime.now()
  42. tensorboard_name = now.strftime("%Y-%m-%d-%H:%M")
  43. tensorboard_name = log_name + '-' + tensorboard_name
  44. path = "logs/" + tensorboard_name
  45. tensorboard = TensorBoard(log_dir=path)
  46.  
  47. # create description file
  48. if not os.path.exists(path):
  49. os.makedirs(path)
  50.  
  51. # train the mode
  52. model.fit_generator(generator=training_generator,
  53. validation_data=validation_generator,
  54. use_multiprocessing=True,
  55. workers=6,
  56. epochs=epochs,
  57. callbacks=[checkpoint, tensorboard])
  58.  
  59. model.save(model_path)
  60.  
  61. del model
  62. gc.collect()
Advertisement
Add Comment
Please, Sign In to add comment