lstm_ray.py

#LSTM-RAY.py
#Note: the pytorch_lightning imports in darts/models/forecasting (in
#pl_forecasting_module.py and torch_forecasting_module.py) must be changed to lightning.pytorch.
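#Alternatively (an untested sketch; aliasing may not cover submodule imports
#such as pytorch_lightning.callbacks), the package could be aliased before
#darts is imported, instead of editing the darts source:
#   import sys, lightning.pytorch
#   sys.modules["pytorch_lightning"] = lightning.pytorch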

import pickle
import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from darts import TimeSeries
from darts.models import RNNModel
from darts.dataprocessing.transformers import Scaler
from darts.metrics import rmse, smape
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
from ray.train.lightning import RayTrainReportCallback, RayLightningEnvironment, RayDDPStrategy, prepare_trainer
from torch.nn import MSELoss
from lightning.pytorch import Trainer

from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.plugins.environments import SLURMEnvironment
import torch
from torch import nn, Tensor
import torch.optim.lr_scheduler as lr

from torchmetrics import MeanAbsoluteError, SymmetricMeanAbsolutePercentageError, MetricCollection
from ray.train import RunConfig, ScalingConfig, CheckpointConfig, FailureConfig
from ray.train.torch import TorchTrainer, TorchConfig
import os

#os.environ["CUDA_VISIBLE_DEVICES"]=str(torch.cuda.device_count())


def loadScaledDataset():
    # Returns scaler, trainListScaled, valListScaled
    return pickle.load(open('large_reduced_scaler_train_val_scaled_dfs.pkl', 'rb'))

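# For reference, a hypothetical sketch of how such a pickle could have been
# produced (file and column names here are illustrative assumptions, not taken
# from the original project):
#
#   series = TimeSeries.from_csv("data.csv", time_col="time", value_cols="value")
#   train, val = series.split_after(0.8)
#   scaler = Scaler()
#   train_scaled = scaler.fit_transform(train)
#   val_scaled = scaler.transform(val)
#   with open('large_reduced_scaler_train_val_scaled_dfs.pkl', 'wb') as f:
#       pickle.dump((scaler, train_scaled, val_scaled), f)
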
def train_model(config):
    train = ray.get(config["train_ref"])
    val = ray.get(config["val_ref"])

    MODELNAME = "lstm_optim_1"

    NUM_NODES = 1
    torch.set_float32_matmul_precision('high')
    accel = "cuda" if torch.cuda.is_available() else "cpu"
    print("torch device: " + (str(torch.cuda.current_device()) if accel == "cuda" else "cpu"))

    modelParams = {
        "input_chunk_length": config["input_chunk_length"],
        "training_length": config["input_chunk_length"],
        "model": "LSTM",
        "hidden_dim": config["hidden_dim"],
        "n_rnn_layers": config["n_rnn_layers"],
        "dropout": config["dropout"],
        "batch_size": config["batch_size"],
        "n_epochs": config["n_epochs"],
        "loss_fn": nn.HuberLoss(),
        #"loss_fn":rmse, #see if this works
        "optimizer_cls": torch.optim.AdamW,
        #"random_state":42,
        "log_tensorboard": True,
        "save_checkpoints": True,
        "force_reset": True
    }
    pl_trainer_kwargs = {
        "devices": "auto",
        "num_nodes": NUM_NODES,
        "gradient_clip_val": 0.1,
        "max_epochs": config["n_epochs"],
        "default_root_dir": "ckpts/",
    }

    initial_trainer = Trainer(**pl_trainer_kwargs,
                              strategy=RayDDPStrategy(),
                              plugins=[RayLightningEnvironment()],
                              callbacks=[RayTrainReportCallback()],
                              enable_checkpointing=False)
    initial_trainer = prepare_trainer(initial_trainer)

    initial_model = RNNModel(
        **modelParams,
    )

    # Find a starting learning rate with the learning-rate finder
    lr_finder = initial_model.lr_find(series=train, val_series=val, trainer=initial_trainer)
    base_lr = lr_finder.suggestion()
    #max_lr = 4 * base_lr #unused; superseded by config["max_lr_scale"] below

    model = RNNModel(
        **modelParams,
        optimizer_kwargs={
            'lr': base_lr,
            'weight_decay': config["weight_decay"]
        },
        lr_scheduler_cls=lr.CyclicLR,
        lr_scheduler_kwargs={
            "base_lr": base_lr,
            "max_lr": base_lr * config["max_lr_scale"],
            "mode": 'exp_range',
            'gamma': 0.9,
            'cycle_momentum': False
        }
    )
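    # In 'exp_range' mode, CyclicLR scales the cycle amplitude by
    # gamma**iteration, so the effective peak LR decays from
    # base_lr * config["max_lr_scale"] toward base_lr as training progresses.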
    print("training")
    model.fit(series=train, val_series=val, trainer=initial_trainer)
    # Code will only get here if it's not terminated by ray's time limit
    # Keep the first half of the validation set as input; predict over the 2nd half
    val_true = val.drop_after(0.5)
    val_pred = model.predict(n=val_true.n_timesteps - 1, series=val_true, trainer=initial_trainer)
    print("reporting")
    errors = {"rmse": rmse(val, val_pred), "smape": smape(val, val_pred)}
    print(config)
    print(errors)
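    # RayTrainReportCallback above already reports Lightning metrics such as
    # val_loss to Ray each epoch; a final manual report of these errors could
    # look like this (sketch):
    #   ray.train.report({"rmse": errors["rmse"], "smape": errors["smape"]})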

def get_ray_train_configs(config: dict = None, num_samples: int = 16):
    if config is None: raise ValueError("No config passed!")

    num_devices = torch.cuda.device_count()
    num_devices = 1 if num_devices == 0 else num_devices
    print("num_devices: " + str(num_devices))

    scaling_config = ScalingConfig(
        num_workers=num_devices, use_gpu=True,
        accelerator_type="A100",
        #resources_per_worker={"CPU": 2, "GPU": 1}
    )
    col_list = list(config.keys())
    col_list.remove("train_ref")
    col_list.remove("val_ref")
    reporter = CLIReporter(
        parameter_columns=col_list,
        metric_columns=["error", "training_iteration"],
        metric="error", mode="min"
    )
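    # Note: reporter is constructed but never attached to the run; depending
    # on the Ray version it could be wired in via RunConfig's progress_reporter
    # argument (an assumption; check the installed Ray API).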
    run_config = RunConfig(
        checkpoint_config=CheckpointConfig(
            num_to_keep=2,
            checkpoint_score_attribute="val_loss",
            checkpoint_score_order="min",
        ),
        failure_config=FailureConfig(max_failures=1), #Should retry failed runs once
        storage_path="/home/farooqzahid/transfer1/ray_results"
    )

    scheduler = ASHAScheduler(
        #time_attr = "time_total_s" or "total_time_s" #time option
        time_attr="time_total_s",
        metric="val_loss", #Scaled data, so average rmse is similar to average mape in application
        mode="min",
        max_t=80000, #default t is number of iterations #SET FOR ARC
        grace_period=20,
        reduction_factor=2
    )
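    # With time_attr="time_total_s", grace_period=20, and reduction_factor=2,
    # ASHA evaluates trials at rungs of roughly 20s, 40s, 80s, ... of training
    # time, stopping the bottom half at each rung, up to max_t=80000s.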
    ray_trainer = TorchTrainer(
        train_loop_per_worker=train_model,
        torch_config=TorchConfig(backend="gloo"),
        scaling_config=scaling_config,
        run_config=run_config
    )

    tune_config = tune.TuneConfig(
        #metric = "HuberLoss", mode = "min", #might need to be HuberLoss OR error
        scheduler=scheduler,
        num_samples=num_samples,
        max_concurrent_trials=8
        #time_budget_s = 000 #Can set this on ARC
    )

    tuner = tune.Tuner(
        ray_trainer,
        param_space={"train_loop_config": config},
        tune_config=tune_config
    )

    return tuner

def main(filepath=None, time_col=None, value_col=None):

    scaler, train, val = loadScaledDataset()
    del scaler
    #ray.init(ignore_reinit_error=True)

    train_ref = ray.put(train)
    val_ref = ray.put(val)
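    # ray.put stores the series in the shared Ray object store once; passing
    # the refs through the trial config avoids re-serializing the full dataset
    # for every trial.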
    config = {
        "input_chunk_length": tune.choice([24, 48, 92]),
        #"training_length":tune.sample_from(lambda spec: spec.config.input_chunk_length),
        "hidden_dim": tune.choice([10, 25, 50]),
        "n_rnn_layers": tune.choice([16, 128, 256]),
        "dropout": tune.uniform(0, 0.4),
        "batch_size": tune.choice([32, 256, 1024]),
        "n_epochs": tune.randint(50, 2000),
        "max_lr_scale": tune.uniform(1, 5),
        "weight_decay": tune.choice([1e-3, 1e-5, 1e-7]),
        "train_ref": train_ref,
        "val_ref": val_ref
    }

    tuner = get_ray_train_configs(config=config, num_samples=16)

    results = tuner.fit()
    return results

if __name__ == "__main__":
    main()