Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms, models
- from torchvision.transforms import functional as TF
- from PIL import Image
- import scipy.io
- import os
- import matplotlib.pyplot as plt
- import torch.nn as nn
- import copy
- import json
- import numpy as np
- from sklearn.metrics import precision_score, recall_score, f1_score
- import seaborn as sns
# === CUSTOM DATASET CLASS ===
class OxfordFlowersDataset(Dataset):
    """Oxford-102 Flowers dataset backed by the official .mat split/label files.

    Images are expected on disk as ``image_XXXXX.jpg`` (1-based id, zero-padded
    to five digits) inside ``images_folder``.
    """

    def __init__(self, images_folder, split_file, labels_file, split='train', transform=None):
        """
        Args:
            images_folder: Directory containing the JPEG images.
            split_file: Path to ``setid.mat`` holding 'trnid'/'valid'/'tstid'
                arrays of 1-based image ids.
            labels_file: Path to ``imagelabels.mat`` holding a 1-based 'labels'
                array (one entry per image).
            split: One of 'train', 'val', 'test'.
            transform: Optional callable applied to each loaded PIL image.

        Raises:
            ValueError: If ``split`` is not one of the known names.
        """
        self.images_folder = images_folder
        self.labels = scipy.io.loadmat(labels_file)['labels'][0]
        self.transform = transform
        splits = scipy.io.loadmat(split_file)
        if split == 'train':
            self.indices = splits['trnid'][0]
        elif split == 'val':
            self.indices = splits['valid'][0]
        elif split == 'test':
            self.indices = splits['tstid'][0]
        else:
            raise ValueError(f"Unknown split: {split}")

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx`` within this split."""
        real_idx = self.indices[idx]  # 1-based image id from the .mat split
        img_path = os.path.join(self.images_folder, f'image_{real_idx:05d}.jpg')
        image = Image.open(img_path).convert('RGB')
        # BUGFIX: cast to a plain Python int. The raw .mat labels slice out as
        # a narrow numpy scalar (presumably uint8 -- confirm against the file),
        # and the default DataLoader collate would then build a narrow-dtype
        # target tensor that CrossEntropyLoss rejects; int values collate to
        # int64. Also converts the 1-based label to a 0-based class id.
        label = int(self.labels[real_idx - 1]) - 1
        if self.transform:
            image = self.transform(image)
        return image, label
def main():
    """Fine-tune a frozen ResNet-18 (new classification head only) on the
    Oxford-102 Flowers dataset.

    Per-epoch train/val loss, accuracy, precision, recall and macro-F1 are
    printed, appended to ``logs/training_metrics.json`` and plotted under
    ``logs/plots/``. A checkpoint is written every 5 epochs and the weights
    with the best validation macro-F1 are saved to
    ``checkpoints/best_model.pth`` and restored at the end.
    """
    # === SET DEVICE ===
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs("logs/plots", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    # === TRANSFORMATIONS ===
    # ImageNet normalization stats, matching the pretrained backbone.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # === PATH CONFIG ===
    images_folder = 'dataset/102flowers'
    labels_file = 'dataset/imagelabels.mat'
    # NOTE(review): setid.mat is resolved from the working directory while the
    # other dataset files live under dataset/ -- confirm this is intentional.
    split_file = 'setid.mat'

    # === DATASETS & DATALOADERS ===
    train_dataset = OxfordFlowersDataset(images_folder, split_file, labels_file, split='train', transform=transform)
    val_dataset = OxfordFlowersDataset(images_folder, split_file, labels_file, split='val', transform=transform)
    test_dataset = OxfordFlowersDataset(images_folder, split_file, labels_file, split='test', transform=transform)

    BATCH_SIZE = 32
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
    # Built but not consumed in this script; kept for a follow-up test pass.
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

    # === MODEL SETUP ===
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    for param in model.parameters():
        param.requires_grad = False  # freeze the pretrained backbone
    num_classes = 102
    # A freshly created Linear has requires_grad=True by default, so only the
    # head is trainable (the original also re-set this flag redundantly).
    model.fc = nn.Linear(512, num_classes)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    # Only hand the trainable head to the optimizer; frozen params never
    # receive gradients, so this is behaviorally equivalent but explicit.
    optimizer = torch.optim.Adam(
        (p for p in model.parameters() if p.requires_grad), lr=0.001)

    train_losses = []
    val_losses = []
    # Snapshot of the best-so-far weights; refreshed whenever val F1 improves
    # (the original took this copy once and never updated it, so the final
    # "restore best weights" actually restored the *initial* weights).
    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_f1 = 0.0
    num_epochs = 10

    def get_transform_summary(t):
        """Return human-readable descriptions of a transform (or pipeline)."""
        if isinstance(t, transforms.Compose):
            return [str(step) for step in t.transforms]
        return [str(t)]

    training_log = {
        "augmentations": {
            "train": get_transform_summary(transform)
        },
        "epochs": []
    }

    def save_plot(values, title, ylabel, filename):
        """Plot `values` against epoch number and save under logs/plots/."""
        plt.figure()
        plt.plot(range(1, len(values) + 1), values, marker='o')
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.grid(True)
        # BUGFIX: the original saved every figure to the same literal path,
        # ignoring `filename`, so each plot overwrote the previous one.
        plt.savefig(f"logs/plots/{filename}")
        plt.close()

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print('-' * 30)

        # === TRAINING ===
        model.train()
        running_loss = 0.0
        all_preds = []
        all_targets = []
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # loss.item() is a batch mean; weight by batch size so the epoch
            # average is exact even with a ragged final batch.
            running_loss += loss.item() * inputs.size(0)
            preds = torch.argmax(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

        epoch_train_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_train_loss)
        train_precision = precision_score(all_targets, all_preds, average='macro', zero_division=0)
        train_recall = recall_score(all_targets, all_preds, average='macro', zero_division=0)
        train_f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)
        train_accuracy = np.mean(np.array(all_preds) == np.array(all_targets))
        print(f"Train Loss: {epoch_train_loss:.4f} | Accuracy: {train_accuracy:.4f} | Precision: {train_precision:.4f} | Recall: {train_recall:.4f} | F1: {train_f1:.4f}")

        # === VALIDATION ===
        model.eval()
        running_val_loss = 0.0
        correct = 0
        val_preds = []
        val_targets = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                running_val_loss += loss.item() * inputs.size(0)
                preds = torch.argmax(outputs, 1)
                correct += (preds == labels).sum().item()
                val_preds.extend(preds.cpu().numpy())
                val_targets.extend(labels.cpu().numpy())

        epoch_val_loss = running_val_loss / len(val_loader.dataset)
        val_losses.append(epoch_val_loss)
        val_acc = correct / len(val_loader.dataset)
        val_precision = precision_score(val_targets, val_preds, average='macro', zero_division=0)
        val_recall = recall_score(val_targets, val_preds, average='macro', zero_division=0)
        val_f1 = f1_score(val_targets, val_preds, average='macro', zero_division=0)
        print(f"Val Loss: {epoch_val_loss:.4f} | Accuracy: {val_acc:.4f} | Precision: {val_precision:.4f} | Recall: {val_recall:.4f} | F1: {val_f1:.4f}")

        # === SAVE JSON METRICS ===
        # Rewritten in full every epoch so a crash still leaves a valid file.
        training_log["epochs"].append({
            "epoch": epoch + 1,
            "train_loss": epoch_train_loss,
            "train_accuracy": train_accuracy,
            "train_precision": train_precision,
            "train_recall": train_recall,
            "train_f1": train_f1,
            "val_loss": epoch_val_loss,
            "val_accuracy": val_acc,
            "val_precision": val_precision,
            "val_recall": val_recall,
            "val_f1": val_f1
        })
        with open("logs/training_metrics.json", "w") as f:
            json.dump(training_log, f, indent=4)

        # === SAVE PLOTS ===
        save_plot(train_losses, "Training Loss", "Loss", "train_loss.png")
        save_plot(val_losses, "Validation Loss", "Loss", "val_loss.png")
        save_plot([e["val_accuracy"] for e in training_log["epochs"]], "Validation Accuracy", "Accuracy", "val_accuracy.png")
        save_plot([e["val_precision"] for e in training_log["epochs"]], "Validation Precision", "Precision", "val_precision.png")
        save_plot([e["val_recall"] for e in training_log["epochs"]], "Validation Recall", "Recall", "val_recall.png")
        save_plot([e["val_f1"] for e in training_log["epochs"]], "Validation F1 Score", "F1 Score", "val_f1.png")

        # === CHECKPOINT ===
        if (epoch + 1) % 5 == 0:
            ckpt_path = f"checkpoints/epoch_{epoch + 1}.pth"
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, ckpt_path)
            print(f"💾 Saved checkpoint at {ckpt_path}")

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            # BUGFIX: refresh the in-memory best-weights copy so the restore
            # below actually loads the best epoch, not the initial weights.
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), "checkpoints/best_model.pth")
            print("✅ New best model found and saved!")

    # Restore best weights
    model.load_state_dict(best_model_wts)

    # Plot loss curves
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training vs Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.show()


if __name__ == '__main__':
    main()
Advertisement
Add Comment
Please, Sign In to add comment