Advertisement
tryhardqaq

VGG16_practice

Apr 29th, 2023
619
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 12.12 KB | Source Code | 0 0
# The end of this file explains how the data was obtained and organized.
  2.  
  3. # import
  4. import torch
  5. import torch.nn as nn
  6. from torch.utils.data import Dataset
  7. from torch.utils.data import DataLoader
  8. from torch.utils.data import random_split
  9. from torchvision import datasets, transforms
  10. import matplotlib.pyplot as plt
  11. import numpy as np
  12.  
  13. # 若 CUDA 環境可用,則使用 GPU 計算,否則使用 CPU
  14. device = "cuda" if torch.cuda.is_available() else "cpu"
  15. print(f"Using {device=}")
  16.  
  17. # 設定圖片轉換器
  18. transform = transforms.Compose([
  19.     transforms.Resize((224, 224)),  # 將圖片大小轉換為 224x224
  20.     transforms.ToTensor()  # 轉換成 PyTorch 張量
  21. ])
  22.  
  23. # 讀取資料夾中的圖片資料
  24. dataset = datasets.ImageFolder(root='pizza_classify', transform=transform)
  25.  
  26. # 將資料集分成訓練集和驗證集
  27. train_size = int(len(dataset) * 0.8)  # 設定訓練集佔 80%
  28. valid_size = len(dataset) - train_size
  29.  
  30. # 隨機分布
  31. train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])
  32. # 設定資料讀取器
  33. batch_size = 16
  34. train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  35. valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
  36.  
  37. # 資料類別:
  38. print(dataset.class_to_idx)
  39.  
  40. # 訓練集、驗證集資料筆數
  41. train_dataset_num = len(train_dataset)
  42. valid_dataset_num = len(valid_dataset)
  43. print(f'{train_dataset_num=}')
  44. print(f'{valid_dataset_num=}')
  45.  
  46. # 訓練集batch、驗證集batch數(批次數 即 iterantion)
  47. print(f'{len(train_dataloader)=}')
  48. print(f'{len(valid_dataloader)=}')
  49.  
# VGG16 model
  51. class VGG16(nn.Module):
  52.     def __init__(self, num_classes=2):
  53.         super(VGG16, self).__init__()
  54.         self.features = nn.Sequential(
  55.             nn.Conv2d(3, 64, kernel_size=3, padding=1), # 224
  56.             nn.BatchNorm2d(64),
  57.             nn.ReLU(inplace=True),
  58.             nn.Conv2d(64, 64, kernel_size=3, padding=1), # 224
  59.             nn.BatchNorm2d(64),
  60.             nn.ReLU(inplace=True),
  61.             nn.MaxPool2d(kernel_size=2, stride=2), # 112
  62.            
  63.             nn.Conv2d(64, 128, kernel_size=3, padding=1), # 112
  64.             nn.BatchNorm2d(128),
  65.             nn.ReLU(inplace=True),
  66.             nn.Conv2d(128, 128, kernel_size=3, padding=1), # 112
  67.             nn.BatchNorm2d(128),
  68.             nn.ReLU(inplace=True),
  69.             nn.MaxPool2d(kernel_size=2, stride=2), # 56
  70.            
  71.             nn.Conv2d(128, 256, kernel_size=3, padding=1), # 56
  72.             nn.BatchNorm2d(256),
  73.             nn.ReLU(inplace=True),
  74.             nn.Conv2d(256, 256, kernel_size=3, padding=1), # 56
  75.             nn.BatchNorm2d(256),
  76.             nn.ReLU(inplace=True),
  77.             nn.Conv2d(256, 256, kernel_size=3, padding=1), # 56
  78.             nn.BatchNorm2d(256),
  79.             nn.ReLU(inplace=True),
  80.             nn.MaxPool2d(kernel_size=2, stride=2), # 28
  81.            
  82.             nn.Conv2d(256, 512, kernel_size=3, padding=1), # 28
  83.             nn.BatchNorm2d(512),
  84.             nn.ReLU(inplace=True),
  85.             nn.Conv2d(512, 512, kernel_size=3, padding=1), # 28
  86.             nn.BatchNorm2d(512),
  87.             nn.ReLU(inplace=True),
  88.             nn.Conv2d(512, 512, kernel_size=3, padding=1), # 28
  89.             nn.BatchNorm2d(512),
  90.             nn.ReLU(inplace=True),
  91.             nn.MaxPool2d(kernel_size=2, stride=2), # 14
  92.            
  93.             nn.Conv2d(512, 512, kernel_size=3, padding=1), # 14
  94.             nn.BatchNorm2d(512),
  95.             nn.ReLU(inplace=True),
  96.             nn.Conv2d(512, 512, kernel_size=3, padding=1), # 14
  97.             nn.BatchNorm2d(512),
  98.             nn.ReLU(inplace=True),
  99.             nn.Conv2d(512, 512, kernel_size=3, padding=1), # 14
  100.             nn.BatchNorm2d(512),
  101.             nn.ReLU(inplace=True),
  102.             nn.MaxPool2d(kernel_size=2, stride=2) # 7
  103.         )
  104.         self.classifier = nn.Sequential(
  105.             nn.Linear(25088, 4096),
  106.             nn.ReLU(inplace=True),
  107.             nn.Dropout(p=0.5),
  108.             nn.Linear(4096, 4096),
  109.             nn.ReLU(inplace=True),
  110.             nn.Dropout(p=0.5),
  111.             nn.Linear(4096, num_classes)
  112.         )
  113.  
  114.     def forward(self, x):
  115.         output = self.features(x)
  116.         output = output.view(output.size(0),-1)
  117.         output = self.classifier(output)
  118.         output = torch.softmax(output, dim=1)
  119.         return output
  120.  
# Early-stopping helper
  122. class Stop_and_Save:
  123.     def __init__(self, patience = 10, path = 'vgg16_pizza.pth'):
  124.         self.patience = patience
  125.         self.path = path
  126.         self.best_val = None
  127.         self.counter = 0
  128.         self.earlystop = False
  129.         self.min_val = np.inf
  130.    
  131.     def __call__(self,val_loss,model):
  132.         if self.best_val is None:
  133.             self.best_val = val_loss
  134.         elif val_loss >= self.best_val:
  135.             self.counter += 1
  136.             if self.counter == self.patience:
  137.                 self.earlystop = True
  138.         else:
  139.             self.best_val = val_loss
  140.             if val_loss <= self.min_val:
  141.                 self.save_checkpoint(model)
  142.                 self.min_val = val_loss
  143.             self.counter = 0
  144.  
  145.     def save_checkpoint(self, model):
  146.             torch.save(model.state_dict(), self.path)
  147.  
  148. # 建立模型並輸出模型資訊
  149. model = VGG16(num_classes=2)
  150. # print(model)
  151.  
  152. # 將model移至模型
  153. model.to(device)
  154.  
  155. # 設定損失函數、優化器和學習率
  156. criterion = nn.CrossEntropyLoss()
  157.  
  158. # VGG因為參數過多,不能使用ADAM
  159. # 我也不知道這是不是真的,但總之改成SGD之後loss有在好好變動
  160. optimizer = torch.optim.SGD(model.parameters(),lr=0.005,momentum=0.9)
  161. # 用調度器適當縮小learning rate
  162. scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,factor=0.2,patience=5,mode='min',verbose=False)
  163.  
  164. # 具體化停止器
  165. early_stopping = Stop_and_Save(patience=10)
  166.  
  167. # 設定最大epochs和訓練過程紀錄
  168. max_epochs = 200
  169. # train的部分:
  170. train_loss_history = []
  171. train_accuracy_history = []
  172. train_precision_history = []
  173. # valid的部分:
  174. valid_loss_history = []
  175. valid_accuracy_history = []
  176. valid_precision_history = []
  177.  
  178. # 訓練迴圈
  179. # 把數據移往gpu必須在迴圈中執行是因為GPU空間太小,最好用mini_batch的方式送入
  180.  
  181. for epoch in range(max_epochs):
  182.     # 每個epoch開始時,需要歸零的計數器
  183.     train_loss_total = 0.0
  184.     correct_train_total = 0
  185.     precision_train_total = 0
  186.     guess_train_total = 0
  187.     train_iteration = 0
  188.  
  189.     valid_loss_total = 0.0
  190.     correct_valid_total = 0
  191.     precision_valid_total = 0
  192.     guess_valid_total = 0
  193.     valid_iteration = 0
  194.  
  195.     model.train()
  196.     for i,(inputs,labels) in enumerate(train_dataloader):
  197.         # 將資料移到GPU上
  198.         inputs, labels = inputs.to(device), labels.to(device)
  199.         # 清空優化器的梯度
  200.         optimizer.zero_grad()
  201.         # 模型向前計算
  202.         train_outputs = model(inputs)
  203.         # 計算損失
  204.         train_loss = criterion(train_outputs, labels)
  205.         # 反向傳播
  206.         train_loss.backward()
  207.         # 用優化器更新權重
  208.         optimizer.step()
  209.  
  210.         # 數據紀錄(accuracy和precision必須用累加否則會有剛好全部都選到非披薩的可能)
  211.         # train_loss累加
  212.         train_loss_total += train_loss.item()
  213.         # 獲取預測標籤
  214.         _, preds = torch.max(train_outputs,1)
  215.         # 訓練正確的次數
  216.         correct_train_total += ((preds == labels).sum()).item()
  217.         # 猜pizza且真的為pizza的次數
  218.         precision_train_total += ((torch.logical_and(preds==labels,preds==1)).sum()).item()
  219.         # 猜pizza為真
  220.         guess_train_total += ((preds == 1).sum()).item()
  221.  
  222.         train_iteration += 1
  223.  
  224.     model.eval()
  225.     for i,(inputs,labels) in enumerate(valid_dataloader):
  226.         # 即使在eval模式下,pytorch仍會跟蹤梯度占用內存,所以仍要使用no_grad
  227.         with torch.no_grad():
  228.             # 將資料移到GPU上
  229.             inputs, labels = inputs.to(device), labels.to(device)
  230.             valid_outputs = model(inputs)
  231.             # 計算損失
  232.             valid_loss = criterion(valid_outputs,labels)
  233.  
  234.             # 數據紀錄(同上)
  235.             # valid_loss累加
  236.             valid_loss_total += valid_loss.item()
  237.             # 獲取預測標籤
  238.             _, preds = torch.max(valid_outputs,1)
  239.             # 驗證正確的次數
  240.             correct_valid_total += ((preds == labels).sum()).item()
  241.             # 猜pizza且真的為pizza的次數(驗證)
  242.             precision_valid_total += ((torch.logical_and(preds==labels,preds==1)).sum()).item()
  243.             guess_valid_total += ((preds == 1).sum()).item()
  244.             valid_iteration += 1
  245.            
  246.  
  247.     # 每個epoch報告一次
  248.     # train的部分:
  249.     train_loss_avg = train_loss_total/train_iteration
  250.     train_loss_history.append(train_loss_avg)
  251.  
  252.     train_accuracy = correct_train_total*100/train_dataset_num
  253.     train_accuracy_history.append(train_accuracy)
  254.  
  255.     train_prec = precision_train_total*100/guess_train_total
  256.     train_precision_history.append(train_prec)
  257.     # valid的部分:
  258.     valid_loss_avg = valid_loss_total/valid_iteration
  259.     valid_loss_history.append(valid_loss_avg)
  260.  
  261.     valid_accuracy = correct_valid_total*100/valid_dataset_num
  262.     valid_accuracy_history.append(valid_accuracy)
  263.  
  264.     valid_prec = precision_valid_total*100/guess_valid_total
  265.     valid_precision_history.append(valid_prec)
  266.  
  267.  
  268.     # 使用調度器
  269.     scheduler.step(valid_loss_avg)
  270.  
  271.     # 使用停止器
  272.     early_stopping(valid_loss_avg,model=model)
  273.     if early_stopping.earlystop:
  274.         print(f"early stop triggered by {valid_loss_avg=}! at epochs:{epoch}")
  275.         break
  276.    
  277.     # if epoch%1 == 0:
  278.     #     print(f'epoch:{epoch}')
  279.     #     print(f'{train_loss_avg=} ; {train_accuracy=}% ; {train_prec=}%')
  280.     #     print(f'{valid_loss_avg=} ; {valid_accuracy=}% ; {valid_prec=}%')
  281.  
  282. # 輸出訓練歷史
  283. plt.plot(train_loss_history, label='train', color = 'deepskyblue')
  284. plt.plot(valid_loss_history, label='valid', color = 'r')
  285. plt.title('Loss trend')
  286. plt.xlabel('epochs')
  287. plt.ylabel('Loss')
  288. plt.legend()
  289. plt.show()
  290.  
  291. plt.plot(train_accuracy_history, label='train', color = 'deepskyblue')
  292. plt.plot(valid_accuracy_history, label='valid', color = 'r')
  293. plt.title('accuracy')
  294. plt.xlabel('epochs')
  295. plt.ylabel('accuracy')
  296. plt.legend()
  297. plt.show()
  298.  
  299. plt.plot(train_precision_history, label='train', color = 'deepskyblue')
  300. plt.plot(valid_precision_history, label='valid', color = 'r')
  301. plt.title('precision')
  302. plt.xlabel('epochs')
  303. plt.ylabel('precision')
  304. plt.legend()
  305. plt.show()
  306.  
  307.  
# data
# From Food-101, I pick 10 random images from each food category.
# The code below helps you randomly collect 1000 non-pizza images from food_101.
# Remember to delete/move the pizza folder out of food_101 before using this.
  312.  
  313. # all import needed is here
  314. import os
  315. import numpy as np
  316. import shutil
  317.  
  318. # get information about pizza dataset
  319. pizza_Path = "pizza_classify"
  320. category_list = os.listdir(pizza_Path)
  321. print(category_list)
  322. pizza_list = os.listdir(pizza_Path+"/pizza")
  323. print(pizza_list)
  324. pizzas = len(pizza_list)
  325. print(pizzas)
  326.  
  327. # first calculate number
  328. img_Path = "food_101/images"
  329. food_list = os.listdir(img_Path)
  330. foods = len(food_list)
  331. print(f'{foods=}')
  332. divide_foods = pizzas//foods
  333. remain_foods = pizzas - foods*divide_foods
  334. print(f'{divide_foods=} and {remain_foods}')
  335.  
  336. # make dir
  337. if not os.path.exists(pizza_Path + "/not_pizza"):
  338.     os.mkdir(pizza_Path + "/not_pizza")
  339. counts = 0
  340.  
  341. for i in range(foods):
  342.     pick_food_path = food_list[i]
  343.     pick_food_path = img_Path + "/" + pick_food_path
  344.     pick_food_list = os.listdir(pick_food_path)
  345.     pick_food_number = len(pick_food_list)
  346.     # print(f'{pick_food_path} have {pick_food_number} picture in it')
  347.     # print(pick_food_list)
  348.     pick_food_list = np.array(pick_food_list)  # necessary
  349.     # print(pick_food_number,divide_foods)
  350.     pick_foods = pick_food_list[np.random.choice(pick_food_number,divide_foods,replace=False)]
  351.     # print(pick_foods)
  352.     for j in range(divide_foods):
  353.         shutil.copyfile(f'{pick_food_path}/{pick_foods[j]}',f'{pizza_Path}/not_pizza/not_pizza_{counts:04d}.jpg')
  354.         counts += 1
  355. print(f'{counts=}')
Tags: VGG16 food 101
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement