Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def train_data_tune_checkpoint(config, data,
                               checkpoint_dir=None,
                               num_epochs=100,
                               num_gpus=1,
                               ):
    """Train a 3-D HighResNet segmentation model for one Ray Tune trial.

    Builds a PyTorch Lightning ``Trainer`` wired into Ray Tune (TensorBoard
    logging under the trial directory, checkpoint + "loss" metric reported at
    each validation end) and fits the model on ``data``.

    Args:
        config: Tune trial config; ``config["lr"]`` is the learning rate
            (mutated by PBT). Falls back to 5e-4 if absent.
        data: a LightningDataModule supplying train/val loaders.
        checkpoint_dir: if set by Tune (PBT exploit step), resume training
            from ``<checkpoint_dir>/checkpoint``.
        num_epochs: maximum number of training epochs.
        num_gpus: GPUs per trial; fractional values are rounded up.
    """
    kwargs = {
        "max_epochs": num_epochs,
        # If fractional GPUs passed in, convert to int.
        "gpus": math.ceil(num_gpus),
        # Log into the trial's own directory so TensorBoard groups runs per trial.
        "logger": TensorBoardLogger(
            save_dir=tune.get_trial_dir(), name="", version="."),
        # Progress bar output would spam the Tune driver log.
        "progress_bar_refresh_rate": 0,
        "callbacks": [
            # Reports ptl/val_loss to Tune as "loss" and saves a checkpoint
            # after every validation epoch (needed for PBT exploit/explore).
            TuneReportCheckpointCallback(
                metrics={
                    "loss": "ptl/val_loss",
                },
                filename="checkpoint",
                on="validation_end")
        ]
    }
    if checkpoint_dir:
        # PBT restored this trial from another trial's checkpoint.
        kwargs["resume_from_checkpoint"] = os.path.join(
            checkpoint_dir, "checkpoint")
    hres = monai.networks.nets.HighResNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        acti_type="prelu"
    )
    model = Model(
        net=hres,
        criterion=monai.losses.FocalLoss(softmax=True),
        # BUG FIX: the tuned learning rate was previously hard-coded to 5e-4,
        # so PBT's mutations of config["lr"] had no effect on training.
        learning_rate=config.get("lr", 5e-4),
        optimizer_class=torch.optim.AdamW,
    )
    trainer = pl.Trainer(**kwargs)
    trainer.fit(model, datamodule=data)
def tune_data_pbt(data, num_samples=10, num_epochs=10, gpus_per_trial=0):
    """Run a Population Based Training sweep over the learning rate.

    Launches ``num_samples`` trials of ``train_data_tune_checkpoint`` on
    ``data``, letting PBT perturb ``lr`` every 4 training iterations, and
    prints the best config found (minimizing validation "loss").

    Args:
        data: a LightningDataModule passed through to every trial.
        num_samples: number of concurrent PBT population members.
        num_epochs: max epochs per trial.
        gpus_per_trial: GPUs reserved per trial (0 = CPU only).
    """
    # Initial value only; PBT resamples/perturbs "lr" from the mutation space.
    config = {
        "lr": 1e-3,
    }
    scheduler = PopulationBasedTraining(
        perturbation_interval=4,
        hyperparam_mutations={
            "lr": tune.loguniform(1e-5, 1e-1),
        })
    reporter = CLIReporter(
        parameter_columns=["lr"],
        metric_columns=["loss"])
    analysis = tune.run(
        # BUG FIX: previously referenced `train_prostatex_tune_checkpoint`,
        # which is not defined anywhere in this file (NameError at launch);
        # the trainable defined above is `train_data_tune_checkpoint`.
        tune.with_parameters(
            train_data_tune_checkpoint,
            num_epochs=num_epochs,
            num_gpus=gpus_per_trial,
            data=data),
        resources_per_trial={
            "cpu": 1,
            "gpu": gpus_per_trial
        },
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        name="tune_mnist_pbt")
    print("Best hyperparameters found were: ", analysis.best_config)
- tune_data_pbt(data, num_samples=10, num_epochs=100, gpus_per_trial=1)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement