Guest User

Monok first ML test

a guest
Jul 26th, 2025
53
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 9.15 KB | None | 0 0
  1. import torch
  2. from torch.utils.data import Dataset, DataLoader
  3. from torchvision import transforms, models
  4. from torchvision.transforms import functional as TF
  5. from PIL import Image
  6. import scipy.io
  7. import os
  8. import matplotlib.pyplot as plt
  9. import torch.nn as nn
  10. import copy
  11. import json
  12. import numpy as np
  13. from sklearn.metrics import precision_score, recall_score, f1_score
  14. import seaborn as sns
  15.  
  16. # === CUSTOM DATASET CLASS ===
  17. class OxfordFlowersDataset(Dataset):
  18. def __init__(self, images_folder, split_file, labels_file, split='train', transform=None):
  19. self.images_folder = images_folder
  20. self.labels = scipy.io.loadmat(labels_file)['labels'][0]
  21. self.transform = transform
  22.  
  23. splits = scipy.io.loadmat(split_file)
  24. if split == 'train':
  25. self.indices = splits['trnid'][0]
  26. elif split == 'val':
  27. self.indices = splits['valid'][0]
  28. elif split == 'test':
  29. self.indices = splits['tstid'][0]
  30. else:
  31. raise ValueError(f"Unknown split: {split}")
  32.  
  33. def __len__(self):
  34. return len(self.indices)
  35.  
  36. def __getitem__(self, idx):
  37. real_idx = self.indices[idx]
  38. img_path = os.path.join(self.images_folder, f'image_{real_idx:05d}.jpg')
  39. image = Image.open(img_path).convert('RGB')
  40. label = self.labels[real_idx - 1] - 1 # Convert 1-based to 0-based
  41. if self.transform:
  42. image = self.transform(image)
  43. return image, label
  44.  
  45.  
  46. def main():
  47. # === SET DEVICE ===
  48. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  49. os.makedirs("logs/plots", exist_ok=True)
  50. os.makedirs("checkpoints", exist_ok=True)
  51.  
  52. # === TRANSFORMATIONS ===
  53. transform = transforms.Compose([
  54. transforms.Resize((224, 224)),
  55. transforms.ToTensor(),
  56. transforms.Normalize(mean=[0.485, 0.456, 0.406],
  57. std=[0.229, 0.224, 0.225])
  58. ])
  59.  
  60. # === PATH CONFIG ===
  61. images_folder = 'dataset/102flowers'
  62. labels_file = 'dataset/imagelabels.mat'
  63. split_file = 'setid.mat'
  64.  
  65. # === DATASETS & DATALOADERS ===
  66. train_dataset = OxfordFlowersDataset(images_folder, split_file, labels_file, split='train', transform=transform)
  67. val_dataset = OxfordFlowersDataset(images_folder, split_file, labels_file, split='val', transform=transform)
  68. test_dataset = OxfordFlowersDataset(images_folder, split_file, labels_file, split='test', transform=transform)
  69.  
  70. BATCH_SIZE = 32
  71. train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
  72. val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
  73. test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
  74.  
  75. # === MODEL SETUP ===
  76. model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
  77.  
  78. for param in model.parameters():
  79. param.requires_grad = False
  80.  
  81. num_classes = 102
  82. model.fc = nn.Linear(512, num_classes)
  83.  
  84. for param in model.fc.parameters():
  85. param.requires_grad = True
  86.  
  87. model = model.to(device)
  88. criterion = nn.CrossEntropyLoss()
  89. optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  90.  
  91. train_losses = []
  92. val_losses = []
  93. best_model_wts = copy.deepcopy(model.state_dict())
  94. best_val_loss = float('inf')
  95.  
  96. num_epochs = 10
  97.  
  98. def get_transform_summary(transform):
  99. if isinstance(transform, transforms.Compose):
  100. return [str(t) for t in transform.transforms]
  101. else:
  102. return [str(transform)]
  103.  
  104. transform_summary = get_transform_summary(transform)
  105.  
  106. training_log = {
  107. "augmentations": {
  108. "train": transform_summary
  109. },
  110. "epochs": []
  111. }
  112.  
  113. # At top-level before training loop
  114. best_val_f1 = 0.0
  115.  
  116. for epoch in range(num_epochs):
  117. print(f"\nEpoch {epoch + 1}/{num_epochs}")
  118. print('-' * 30)
  119.  
  120. # === TRAINING ===
  121. model.train()
  122. running_loss = 0.0
  123. all_preds = []
  124. all_targets = []
  125.  
  126. for inputs, labels in train_loader:
  127. inputs, labels = inputs.to(device), labels.to(device)
  128.  
  129. optimizer.zero_grad()
  130. outputs = model(inputs)
  131. loss = criterion(outputs, labels)
  132. loss.backward()
  133. optimizer.step()
  134.  
  135. running_loss += loss.item() * inputs.size(0)
  136. preds = torch.argmax(outputs, 1)
  137. all_preds.extend(preds.cpu().numpy())
  138. all_targets.extend(labels.cpu().numpy())
  139.  
  140. epoch_train_loss = running_loss / len(train_loader.dataset)
  141. train_losses.append(epoch_train_loss)
  142.  
  143. train_precision = precision_score(all_targets, all_preds, average='macro', zero_division=0)
  144. train_recall = recall_score(all_targets, all_preds, average='macro', zero_division=0)
  145. train_f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)
  146. train_accuracy = np.mean(np.array(all_preds) == np.array(all_targets))
  147.  
  148. print(f"Train Loss: {epoch_train_loss:.4f} | Accuracy: {train_accuracy:.4f} | Precision: {train_precision:.4f} | Recall: {train_recall:.4f} | F1: {train_f1:.4f}")
  149.  
  150. # === VALIDATION ===
  151. model.eval()
  152. running_val_loss = 0.0
  153. correct = 0
  154. val_preds = []
  155. val_targets = []
  156.  
  157. with torch.no_grad():
  158. for inputs, labels in val_loader:
  159. inputs, labels = inputs.to(device), labels.to(device)
  160. outputs = model(inputs)
  161. loss = criterion(outputs, labels)
  162. running_val_loss += loss.item() * inputs.size(0)
  163.  
  164. preds = torch.argmax(outputs, 1)
  165. correct += (preds == labels).sum().item()
  166. val_preds.extend(preds.cpu().numpy())
  167. val_targets.extend(labels.cpu().numpy())
  168.  
  169. epoch_val_loss = running_val_loss / len(val_loader.dataset)
  170. val_losses.append(epoch_val_loss)
  171. val_acc = correct / len(val_loader.dataset)
  172. val_precision = precision_score(val_targets, val_preds, average='macro', zero_division=0)
  173. val_recall = recall_score(val_targets, val_preds, average='macro', zero_division=0)
  174. val_f1 = f1_score(val_targets, val_preds, average='macro', zero_division=0)
  175.  
  176. print(f"Val Loss: {epoch_val_loss:.4f} | Accuracy: {val_acc:.4f} | Precision: {val_precision:.4f} | Recall: {val_recall:.4f} | F1: {val_f1:.4f}")
  177.  
  178. # === SAVE JSON METRICS ===
  179. epoch_dict = {
  180. "epoch": epoch + 1,
  181. "train_loss": epoch_train_loss,
  182. "train_accuracy": train_accuracy,
  183. "train_precision": train_precision,
  184. "train_recall": train_recall,
  185. "train_f1": train_f1,
  186. "val_loss": epoch_val_loss,
  187. "val_accuracy": val_acc,
  188. "val_precision": val_precision,
  189. "val_recall": val_recall,
  190. "val_f1": val_f1
  191. }
  192. training_log["epochs"].append(epoch_dict)
  193.  
  194. with open("logs/training_metrics.json", "w") as f:
  195. json.dump(training_log, f, indent=4)
  196.  
  197. # === SAVE PLOTS ===
  198. def save_plot(values, title, ylabel, filename):
  199. plt.figure()
  200. plt.plot(range(1, len(values) + 1), values, marker='o')
  201. plt.title(title)
  202. plt.xlabel("Epoch")
  203. plt.ylabel(ylabel)
  204. plt.grid(True)
  205. plt.savefig(f"logs/plots/{filename}")
  206. plt.close()
  207.  
  208. save_plot(train_losses, "Training Loss", "Loss", "train_loss.png")
  209. save_plot(val_losses, "Validation Loss", "Loss", "val_loss.png")
  210.  
  211. # Save other metrics per epoch
  212. save_plot([e["val_accuracy"] for e in training_log["epochs"]], "Validation Accuracy", "Accuracy", "val_accuracy.png")
  213. save_plot([e["val_precision"] for e in training_log["epochs"]], "Validation Precision", "Precision", "val_precision.png")
  214. save_plot([e["val_recall"] for e in training_log["epochs"]], "Validation Recall", "Recall", "val_recall.png")
  215. save_plot([e["val_f1"] for e in training_log["epochs"]], "Validation F1 Score", "F1 Score", "val_f1.png")
  216.  
  217. # === CHECKPOINT ===
  218. if (epoch + 1) % 5 == 0:
  219. ckpt_path = f"checkpoints/epoch_{epoch + 1}.pth"
  220. torch.save({
  221. 'epoch': epoch + 1,
  222. 'model_state_dict': model.state_dict(),
  223. 'optimizer_state_dict': optimizer.state_dict(),
  224. }, ckpt_path)
  225. print(f"💾 Saved checkpoint at {ckpt_path}")
  226.  
  227.  
  228. if val_f1 > best_val_f1:
  229. best_val_f1 = val_f1
  230. torch.save(model.state_dict(), "checkpoints/best_model.pth")
  231. print("✅ New best model found and saved!")
  232.  
  233. # Restore best weights
  234. model.load_state_dict(best_model_wts)
  235.  
  236. # Plot loss curves
  237. plt.plot(train_losses, label='Train Loss')
  238. plt.plot(val_losses, label='Val Loss')
  239. plt.xlabel('Epoch')
  240. plt.ylabel('Loss')
  241. plt.title('Training vs Validation Loss')
  242. plt.legend()
  243. plt.grid(True)
  244. plt.show()
  245.  
  246. if __name__ == '__main__':
  247. main()
  248.  
Advertisement
Add Comment
Please, Sign In to add comment