Advertisement
canbolat

deit code

Dec 19th, 2024
37
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.99 KB | Source Code | 0 0
  1. from sklearn.metrics import precision_score, recall_score, f1_score
  2. import os
  3. import numpy as np
  4. from PIL import Image
  5. from torch.utils.data import Dataset
  6. import torch
  7. import torch.nn as nn
  8. import torch.optim as optim
  9. import torch.nn.functional as F
  10. from torch.utils.data import DataLoader
  11. #import torchvision.transforms as transforms
  12. from torchvision import transforms
  13. import torchvision
  14. from transformers import DeiTForImageClassificationWithTeacher, DeiTImageProcessor
  15.  
# Preprocessing applied to every image in myDataset.__getitem__:
# X-rays are single-channel, but DeiT expects 3-channel input, so the
# grayscale image is replicated across 3 channels before tensor conversion
# (ToTensor also scales pixel values to [0, 1]).
# NOTE(review): no Resize to the model's expected 384x384 input here —
# presumably the images are already that size on disk; confirm.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])
  20.  
  21. class myDataset(Dataset):
  22. def __init__(self, root_dir):
  23. self.root_dir = root_dir
  24. self.data = []
  25.  
  26. for label in os.listdir(root_dir):
  27. label_dir = os.path.join(root_dir, label)
  28. if os.path.isdir(label_dir):
  29. for file in os.listdir(label_dir):
  30. self.data.append((os.path.join(label_dir, file), int(label)))
  31.  
  32. def __len__(self):
  33. return len(self.data)
  34.  
  35. def __getitem__(self, idx):
  36. img_path, label = self.data[idx]
  37. image = Image.open(img_path)
  38. #print(f'image before normalization: {image}') #DEBUG
  39. image = transform(image)
  40. #print(f'image after normalization to 0-1{image}') #DEBUG
  41. image_np = np.array(image)
  42. return image, label
  43.  
# Use the GPU when available; the model and every batch are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters.
# NOTE(review): lr=0.01 is unusually high for Adam fine-tuning of a
# pretrained transformer (1e-4..1e-5 is typical) — confirm intended.
learning_rate = 0.01
batch_size = 32
num_epochs = 32

# Local snapshot of the distilled DeiT-base (patch16, 384) checkpoint;
# from_pretrained expects config + weights inside this directory.
model_path = "/content/drive/MyDrive/datasets/PyTorchdeit-base-distilled-patch16-384/"
model = DeiTForImageClassificationWithTeacher.from_pretrained(model_path)
model.to(device)

# Folder-per-class datasets (directory name = integer label); the val split
# is currently disabled.
train_dataset = myDataset(root_dir="/content/drive/MyDrive/datasets/KneeOsteoarthritisXray/train")
#val_dataset = myDataset(root_dir="/content/drive/MyDrive/datasets/KneeOsteoarthritisXray/val")
test_dataset = myDataset(root_dir="/content/drive/MyDrive/datasets/KneeOsteoarthritisXray/test")

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
#val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)


# Standard multi-class loss + Adam over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
  65.  
# Supervised fine-tuning loop: one pass over train_loader per epoch,
# reporting the mean batch loss at the end of each epoch.
for epoch in range(num_epochs):
    losses = []  # per-batch loss values for this epoch's average

    for batch_idx, (data, targets) in enumerate(train_loader):
        data = data.to(device=device)
        targets = targets.to(device=device)

        # The WithTeacher model returns a structured output; per the
        # transformers docs its 'logits' field averages the class-token
        # and distillation-token heads.
        scores = model(data)['logits']
        loss = criterion(scores, targets)
        losses.append(loss.item())

        # Backpropagate and update: clear stale gradients first.
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()

    print(f'Cost at epoch {epoch} is {(sum(losses)/len(losses))}')
  83.  
  84.  
  85. def check_accuracy(loader, model):
  86. print("Checking accuracy")
  87. num_correct = 0
  88. num_samples = 0
  89. all_labels = []
  90. all_preds = []
  91.  
  92. model.eval()
  93.  
  94. with torch.no_grad():
  95. for x, y in loader:
  96. x = x.to(device=device)
  97. y = y.to(device=device)
  98.  
  99. scores = model(x)['logits']
  100. print(f'scores: {scores}')
  101. _, predictions = scores.max(1)
  102. print(f'predictions: {predictions}')
  103. print(f'y: {y}')
  104.  
  105. all_labels.extend(y.cpu().numpy())
  106. all_preds.extend(predictions.cpu().numpy())
  107.  
  108. num_correct += (predictions == y).sum() #.item()
  109. num_samples += predictions.size(0)
  110.  
  111. print(f'Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}')
  112.  
  113. precision = precision_score(all_labels, all_preds, average='weighted')
  114. recall = recall_score(all_labels, all_preds, average='weighted')
  115. f1 = f1_score(all_labels, all_preds, average='weighted')
  116.  
  117. print(f'Precision: {precision:.4f}')
  118. print(f'Recall: {recall:.4f}')
  119. print(f'F1-Score: {f1:.4f}')
  120.  
  121. model.train()
  122.  
# Final evaluation on the held-out test split.
print('test set accuracy')
check_accuracy(test_loader, model)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement