Advertisement
Guest User

Ray

a guest
Nov 4th, 2021
55
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.29 KB | None | 0 0
  1. def train_data_tune_checkpoint(config,data,
  2.                                 checkpoint_dir=None,
  3.                                 num_epochs=100,
  4.                                 num_gpus=1,
  5.                                 ):
  6.     kwargs = {
  7.         "max_epochs": num_epochs,
  8.         # If fractional GPUs passed in, convert to int.
  9.         "gpus": math.ceil(num_gpus),
  10.         "logger": TensorBoardLogger(
  11.             save_dir=tune.get_trial_dir(), name="", version="."),
  12.         "progress_bar_refresh_rate": 0,
  13.         "callbacks": [
  14.             TuneReportCheckpointCallback(
  15.                 metrics={
  16.                     "loss": "ptl/val_loss",
  17.                 },
  18.                 filename="checkpoint",
  19.                 on="validation_end")
  20.         ]
  21.     }
  22.  
  23.     if checkpoint_dir:
  24.         kwargs["resume_from_checkpoint"] = os.path.join(
  25.             checkpoint_dir, "checkpoint")
  26.        
  27.     hres = monai.networks.nets.HighResNet(
  28.     spatial_dims=3,
  29.     in_channels=1,
  30.     out_channels=2,
  31.     acti_type="prelu"
  32.         )
  33.  
  34.     model = Model(
  35.     net=hres,
  36.     criterion=monai.losses.FocalLoss(softmax=True),
  37.     learning_rate=5e-4,
  38.     optimizer_class=torch.optim.AdamW,
  39.         )
  40.    
  41.     trainer = pl.Trainer(**kwargs)
  42.  
  43.     trainer.fit(model,datamodule=data)
  44.  
  45. def tune_data_pbt(data, num_samples=10, num_epochs=10, gpus_per_trial=0):
  46.     config = {
  47.         "lr": 1e-3,
  48.     }
  49.  
  50.     scheduler = PopulationBasedTraining(
  51.         perturbation_interval=4,
  52.         hyperparam_mutations={
  53.             "lr": tune.loguniform(1e-5, 1e-1),
  54.         })
  55.  
  56.     reporter = CLIReporter(
  57.         parameter_columns=["lr"],
  58.         metric_columns=["loss"])
  59.  
  60.     analysis = tune.run(
  61.         tune.with_parameters(
  62.             train_prostatex_tune_checkpoint,
  63.             num_epochs=num_epochs,
  64.             num_gpus=gpus_per_trial,
  65.             data=data),
  66.         resources_per_trial={
  67.             "cpu": 1,
  68.             "gpu": gpus_per_trial
  69.         },
  70.         metric="loss",
  71.         mode="min",
  72.         config=config,
  73.         num_samples=num_samples,
  74.         scheduler=scheduler,
  75.         progress_reporter=reporter,
  76.         name="tune_mnist_pbt")
  77.  
  78.     print("Best hyperparameters found were: ", analysis.best_config)
  79.  
  80.  
# Launch the PBT tuning run: 10 population members, up to 100 epochs each,
# one GPU per trial.
# NOTE(review): `data` is not defined anywhere in this paste — it is
# presumably a LightningDataModule built earlier in the full script;
# confirm it is in scope before running, and consider guarding this call
# with `if __name__ == "__main__":`.
tune_data_pbt(data, num_samples=10, num_epochs=100, gpus_per_trial=1)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement