import torch
import yaml
import logging
import onnx
from google.protobuf.json_format import MessageToDict

# import self-made modules
import utils
import pruning
import quantization
import ImageDataLoader
import test_suite_model
import train_maskrcnn
# ---------------------------------------------------------------------------- #
def main():
    # logging setup
    utils.setup_logging()

    # Load the config file
    with open("config.yaml", "r") as file:
        config = yaml.safe_load(file)
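
    # Minimal sketch of the config.yaml layout this script expects. The key names are
    # taken from the lookups used below; the values and the annotation paths are
    # placeholders, not the author's actual configuration:
    #
    #   device: "cuda"
    #   datasets:
    #     train_data_path: "./images/train"
    #     train_annotation_path: "./annotations/train.json"   # placeholder
    #     val_data_path: "./images/val"                        # placeholder
    #     val_annotation_path: "./annotations/val.json"        # placeholder
    #     test_data_path: "./images/test"                      # placeholder
    #     test_annotation_path: "./annotations/test.json"      # placeholder
    #   compression:
    #     prune:
    #       enabled: true
    #       amount: 0.3
    #     quantize:
    #       enabled: false
    #       bit_precision: 8
    #     distill:
    #       enabled: false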

    # Load the model
    model = utils.load_model_to_pytorch(pretrained=True)  # pretrained=True loads the PyTorch pretrained weights

# ---------------------------------------------------------------------------- #
    # Data: fine-tune the model on a small number of images and their annotations
    train_dataloader = ImageDataLoader.return_dataloader(config["datasets"]["train_data_path"],
                                                         config["datasets"]["train_annotation_path"])
    val_dataloader = ImageDataLoader.return_dataloader(config["datasets"]["val_data_path"],
                                                       config["datasets"]["val_annotation_path"])
    test_dataloader = ImageDataLoader.return_dataloader(config["datasets"]["test_data_path"],
                                                        config["datasets"]["test_annotation_path"])

    # Check that the dataloader is working correctly by visualizing a sample
    utils.visualize_one_dataloader_sample(train_dataloader)

    # Train only the detection head
    model, train_losses_head, val_losses_head = train_maskrcnn.train_finetune_head(
        model, train_dataloader, val_dataloader, device=config["device"], num_epochs=10)
    # Optionally train the whole model (backbone included)
    #model, train_losses_backbone, val_losses_backbone = train_maskrcnn.train_backbone(model, train_dataloader, val_dataloader, device=config["device"], num_epochs=5)

    #utils.plot_combined_losses(train_losses_head, train_losses_backbone, val_losses_head, val_losses_backbone)

# ---------------------------------------------------------------------------- #
    # Set up naming components
    name_components = []

    # Apply compression techniques
    if config["compression"]["prune"]["enabled"]:
        logging.info("Applying pruning")

        amount = config["compression"]["prune"]["amount"]
        model = pruning.apply_local_pruning(model, amount)
        name_components.append(f"pruned_{amount}")
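
    # A rough sketch of what pruning.apply_local_pruning could look like if it is built
    # on torch.nn.utils.prune (layer-wise L1 unstructured pruning). This is an assumption
    # about the helper, not its confirmed implementation:
    #
    #   import torch.nn.utils.prune as prune
    #   for module in model.modules():
    #       if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
    #           prune.l1_unstructured(module, name="weight", amount=amount)
    #           prune.remove(module, "weight")  # bake the zeroed weights in permanently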

    if config["compression"]["quantize"]["enabled"]:
        logging.info("Applying quantization")

        bit_precision = config["compression"]["quantize"]["bit_precision"]
        model = quantization.apply_quantization(model, bit_precision)
        name_components.append(f"quantized_{bit_precision}bit")
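
    # A rough sketch of one way quantization.apply_quantization might work: post-training
    # dynamic quantization of Linear layers. This path only covers INT8 weights
    # (torch.ao.quantization.quantize_dynamic does not take an arbitrary bit width),
    # so it is only an illustration of the idea, not the author's helper:
    #
    #   model = torch.ao.quantization.quantize_dynamic(
    #       model, {torch.nn.Linear}, dtype=torch.qint8
    #   )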

    if config["compression"]["distill"]["enabled"]:
        logging.info("Applying distillation")

        name_components.append("distilled")
        pass  # distillation is not implemented yet; this branch only tags the model name

# ---------------------------------------------------------------------------- #
    # Generate & save the ONNX model name based on the applied compression techniques
    base_model_name = "MRCNN_base"
    model_name = f"{base_model_name}_{'_'.join(name_components)}"  # note: with no compression enabled this becomes "MRCNN_base_"
    model_full_file_path = f"./models/{model_name}.onnx"  # model_name carries no extension; .onnx is appended here

    torch_input = torch.randn(1, 3, 1216, 1368)
    utils.save_model_onnx(model_full_file_path, model, torch_input)
    """
    IMPORTANT:
    input: [batch, channels, height, width]
    output: boxes, labels, scores, masks
    """
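
    # Sketch of the torch.onnx.export call that utils.save_model_onnx is assumed to wrap;
    # the opset version and tensor names below are illustrative, not confirmed by the helper:
    #
    #   model.eval()
    #   torch.onnx.export(
    #       model, torch_input, model_full_file_path,
    #       opset_version=11,
    #       input_names=["input"],
    #       output_names=["boxes", "labels", "scores", "masks"],
    #       dynamic_axes={"input": {0: "batch"}},
    #   )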
# ---------------------------------------------------------------------------- # WHERE IT GOES WRONG

    # Run inference on a single image
    image_path = "./images/train/img_ (1).bmp"
    onnx_model_path = "./models/MRCNN_base_.onnx"

    utils.print_model_info(onnx_model_path)

    results = utils.run_inference(image_path, onnx_model_path, conf_threshold=0.5)
    utils.visualize_detections(results)
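
    # Sketch of how utils.run_inference could be implemented with ONNX Runtime; the
    # preprocessing and output handling here are assumptions, not the real helper:
    #
    #   import numpy as np
    #   import onnxruntime as ort
    #   sess = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
    #   input_name = sess.get_inputs()[0].name
    #   img = np.zeros((1, 3, 1216, 1368), dtype=np.float32)  # placeholder preprocessed image
    #   boxes, labels, scores, masks = sess.run(None, {input_name: img})
    #   keep = scores > 0.5  # apply conf_threshold
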
# ---------------------------------------------------------------------------- #
    # Run the test suite on both CPU and GPU
    # metrics consist of: accuracy, F1, precision, recall, jaccard index, auc_roc
    #model_full_file_path = "./models/MRCNN_base_.onnx"
    avg_cpu_inference_time, cpu_metrics, all_ground_truths_cpu, all_predictions_cpu = test_suite_model.run_cpu_tests(model_full_file_path, test_dataloader)
    avg_gpu_inference_time, gpu_metrics, all_ground_truths_gpu, all_predictions_gpu = test_suite_model.run_gpu_tests(model_full_file_path, test_dataloader)
    print(avg_cpu_inference_time, cpu_metrics, all_ground_truths_cpu, all_predictions_cpu)

    # Extract metrics (the "f1_score" key name is inferred from the variable names;
    # adjust it if the test suite uses a different key)
    cpu_accuracy, cpu_f1_score, cpu_precision, cpu_recall, cpu_jaccard_index, cpu_auc_roc = (
        cpu_metrics["accuracy"], cpu_metrics["f1_score"], cpu_metrics["precision"],
        cpu_metrics["recall"], cpu_metrics["jaccard_index"], cpu_metrics["auc_roc"])
    gpu_accuracy, gpu_f1_score, gpu_precision, gpu_recall, gpu_jaccard_index, gpu_auc_roc = (
        gpu_metrics["accuracy"], gpu_metrics["f1_score"], gpu_metrics["precision"],
        gpu_metrics["recall"], gpu_metrics["jaccard_index"], gpu_metrics["auc_roc"])
# ---------------------------------------------------------------------------- #
    # Save the results
    results_file_path = f"./results/{model_name}"  # use the name without the .onnx extension
    #utils.write_out_results(results_file_path, config, inference_time, accuracy_score)

    logging.info("program finished")

if __name__ == "__main__":
    main()