import csv
import os

import numpy as np
import torch
from torch.autograd import Variable

if __name__ == "__main__":
    for epoch in range(START_EPOCH, START_EPOCH + hp.epoch):
        adjust_learning_rate(optimizer, epoch)
        train(epoch, hp.wrong_save)
        # mining(epoch)
        valid(epoch)
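# --- adjust_learning_rate() is called above but not defined in this paste.
# A minimal sketch, assuming the step-decay schedule implied by the
# `args.lr * (0.5 ** LR_DECAY)` expressions printed in valid() below; the
# exact schedule is an assumption, not the author's confirmed code.
def adjust_learning_rate(optimizer, epoch):
    lr = args.lr * (0.5 ** LR_DECAY)  # LR_DECAY is a global updated elsewhere
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr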
RuntimeError: CUDA out of memory. Tried to allocate 74.12 MiB (GPU 0; 5.93 GiB total capacity; 4.73 GiB already allocated; 75.06 MiB free; 19.57 MiB cached)
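# --- The OOM above is typical when validation runs with autograd enabled:
# the old `volatile=True` flag was removed in PyTorch 0.4, so every forward
# pass keeps its graph alive. valid() below is therefore wrapped in
# torch.no_grad(). A small helper (a sketch, not part of the original paste)
# for watching GPU memory while hunting such leaks:
def report_gpu_memory(tag=''):
    # memory_allocated()/memory_cached() are real torch.cuda calls;
    # memory_cached() was later renamed memory_reserved().
    print('%s allocated: %.1f MiB, cached: %.1f MiB'
          % (tag,
             torch.cuda.memory_allocated() / 1024 ** 2,
             torch.cuda.memory_cached() / 1024 ** 2))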
def train(epoch, wrong_save=False):
    ''' train net using patches of slide.
    Save a csv file listing the patch file names predicted incorrectly.

    Args:
        epoch (int): current epoch
        wrong_save (bool): if True, save a csv file listing the patch file
            names predicted incorrectly
    '''
    print('\nEpoch: %d' % epoch)

    net.train()
    train_loss = 0
    correct = 0
    total = 0
    wrong_list = []

    for batch_idx, (inputs, targets, filename) in enumerate(trainloader):
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = torch.FloatTensor(np.array(targets).astype(float)).cuda()

        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = net(inputs)
        outputs = torch.squeeze(outputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        total += targets.size(0)
        batch_size = targets.shape[0]

        # Binarize at THRESHOLD: floor(output + (1 - THRESHOLD)) is 1 iff
        # output >= THRESHOLD. The original added THRESHOLD here, which
        # thresholds at 1 - THRESHOLD; made consistent with valid() below.
        outputs = outputs + Variable((torch.ones(batch_size) * (1 - THRESHOLD)).cuda())
        outputs = torch.floor(outputs)
        correct += outputs.data.eq(targets.data).cpu().sum()
        filename_list = filename

        if wrong_save:
            for idx in range(len(filename_list)):
                if outputs.data[idx] != targets.data[idx]:
                    wrong_list.append(filename_list[idx])

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    if wrong_save:
        with open(cf.wrong_path + 'wrong_data_epoch' + str(epoch) + '.csv', 'w',
                  encoding='utf-8') as wrong_csv:
            wr = csv.writer(wrong_csv)
            for name in wrong_list:
                wr.writerow([name])

    CUR_TRA_ACC.append(100. * correct / total)
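# --- progress_bar() is not defined in this paste; it resembles the console
# helper from kuangliu's pytorch-cifar utils.py. A minimal stand-in with the
# same (current, total, msg) signature:
import sys

def progress_bar(current, total, msg=None):
    bar_len = 30
    filled = int(bar_len * (current + 1) / total)
    line = '[' + '=' * filled + '.' * (bar_len - filled) + '] %d/%d' % (current + 1, total)
    if msg:
        line += ' | ' + msg
    sys.stdout.write('\r' + line)
    if current + 1 == total:
        sys.stdout.write('\n')
    sys.stdout.flush()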
def valid(epoch):
    ''' validate net using patches of slide.
    Save checkpoint if AUC score is higher than saved checkpoint's.

    Args:
        epoch (int): current epoch
    '''
    global BEST_AUC
    global THRESHOLD
    global LR_CHANCE
    global CK_CHANCE
    global LR_DECAY

    net.eval()
    valid_loss = 0
    total = 0
    correct = 0

    outputs_list = np.array([])
    targets_list = np.array([])

    # torch.no_grad() replaces the deprecated `volatile=True` flag; without
    # it, every validation forward pass keeps autograd buffers alive, which
    # is a common cause of the CUDA OOM shown above.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(valloader):
            if USE_CUDA:
                inputs = inputs.cuda()
                targets = torch.FloatTensor(np.array(targets).astype(float)).cuda()

            batch_size = targets.shape[0]
            outputs = net(inputs)
            total += targets.size(0)
            outputs = torch.squeeze(outputs)
            loss = criterion(outputs, targets)
            valid_loss += loss.item()

            _outputs = np.array(outputs.data.cpu()).astype(float)
            _targets = np.array(targets.data.cpu()).astype(float)
            outputs_list = np.append(outputs_list, _outputs)
            targets_list = np.append(targets_list, _targets)

            # Binarize at THRESHOLD: floor(output + (1 - THRESHOLD)) is 1
            # iff output >= THRESHOLD.
            outputs = outputs + (torch.ones(batch_size) * (1 - THRESHOLD)).cuda()
            outputs = torch.floor(outputs)
            correct += int(outputs.eq(targets).cpu().sum())

            progress_bar(batch_idx, len(valloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (valid_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    correct, tp, tn, fp, fn, recall, precision, specificity, f1_score, auc, threshold = stats(outputs_list, targets_list)
    acc = correct / total
    THRESHOLD = threshold

    print(tp, tn, fp, fn)

    print('Acc: %.3f, Recall: %.3f, Prec: %.3f, Spec: %.3f, F1: %.3f, Thres: %.3f, AUC: %.3f'
          % (acc, recall, precision, specificity, f1_score, threshold, auc))
    print('%17s %12s\n%-11s %-8d %-8d\n%-11s %-8d %-8d'
          % ('Tumor', 'Normal', 'pos', tp, fp, 'neg', fn, tn))
    print("lr: ", args.lr * (0.5 ** LR_DECAY), "lr chance:", LR_CHANCE)

    # plot data
    CUR_EPOCH.append(epoch)
    CUR_VAL_ACC.append(acc)
    CUR_LOSS.append(valid_loss / (batch_idx + 1))
    CUR_LR.append(args.lr * (0.5 ** LR_DECAY))

    # Save checkpoint.
    if auc > BEST_AUC:
        print('saving...')
        BEST_AUC = auc
        state = {
            # The original read `net if USE_CUDA else net` (both branches
            # identical); if net is wrapped in DataParallel, `net.module`
            # may have been the intent.
            'net': net,
            'acc': acc,
            'loss': valid_loss,
            'recall': recall,
            'specificity': specificity,
            'precision': precision,
            'f1_score': f1_score,
            'auc': auc,
            'epoch': epoch,
            'lr': args.lr * (0.5 ** LR_DECAY),
            'threshold': threshold
        }
        os.makedirs('./checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.t7')
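# --- stats() is not defined in this paste. Its call site in valid() fixes
# the return signature: (correct, tp, tn, fp, fn, recall, precision,
# specificity, f1_score, auc, threshold) from raw sigmoid outputs and binary
# targets. A plausible sketch using scikit-learn's ROC utilities; picking the
# threshold that maximizes Youden's J statistic is an assumption:
from sklearn.metrics import roc_curve, auc as sk_auc

def stats(outputs, targets):
    fpr, tpr, thresholds = roc_curve(targets, outputs)
    auc_score = sk_auc(fpr, tpr)
    threshold = float(thresholds[np.argmax(tpr - fpr)])  # Youden's J

    preds = (outputs >= threshold).astype(float)
    tp = int(np.sum((preds == 1) & (targets == 1)))
    tn = int(np.sum((preds == 0) & (targets == 0)))
    fp = int(np.sum((preds == 1) & (targets == 0)))
    fn = int(np.sum((preds == 0) & (targets == 1)))

    correct = tp + tn
    recall = tp / (tp + fn) if tp + fn else 0.
    precision = tp / (tp + fp) if tp + fp else 0.
    specificity = tn / (tn + fp) if tn + fp else 0.
    f1_score = (2 * precision * recall / (precision + recall)
                if precision + recall else 0.)
    return (correct, tp, tn, fp, fn, recall, precision,
            specificity, f1_score, auc_score, threshold)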
def get_dataset(train_transform, test_transform, train_max,
                val_max, subtest_max, ratio=0, mining_mode=False):
    ''' dataset function to get train, valid, subtest, test, mining dataset

    Args:
        train_transform (torchvision.transforms): train set transform for data augmentation
        test_transform (torchvision.transforms): test set transform for data augmentation
        train_max (int): size limit of the train set
        val_max (int): size limit of the validation set
        subtest_max (int): size limit of the subtest set
        ratio (int): for mining_mode, inclusion ratio of the train set compared to the mining set
        mining_mode (bool): if True, also return the mining dataset
    '''
    train_dataset = camel(cf.dataset_path + 'train/', usage='train',
                          limit=train_max, transform=train_transform)

    val_dataset = camel(cf.dataset_path + 'validation/', usage='val',
                        limit=val_max, transform=test_transform)

    subtest_dataset = camel(cf.dataset_path + 'test/', usage='subtest',
                            limit=subtest_max, transform=test_transform)

    test_dataset = camel(cf.test_path, usage='test', transform=test_transform)

    if mining_mode:
        mining_dataset = camel(cf.dataset_path + 'mining/', usage='mining',
                               train_ratio=ratio, transform=train_transform)
        return train_dataset, val_dataset, subtest_dataset, test_dataset, mining_dataset
    else:
        return train_dataset, val_dataset, subtest_dataset, test_dataset
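# --- camel() is not defined in this paste. Judging by the loaders above,
# usage='train' yields (image, target, filename) triples while validation
# yields (image, target) pairs. A skeletal torch Dataset along those lines;
# the directory layout and filename-based labeling are assumptions:
from PIL import Image
from torch.utils.data import Dataset

class camel(Dataset):
    def __init__(self, path, usage='train', limit=None, train_ratio=0,
                 transform=None):
        self.usage = usage
        self.transform = transform
        # Assumed layout: flat directory, label encoded in the filename,
        # e.g. '..._tumor.png' -> 1.0, anything else -> 0.0.
        names = sorted(os.listdir(path))
        self.samples = [(os.path.join(path, n),
                         1.0 if 'tumor' in n else 0.0, n)
                        for n in names][:limit]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        filepath, target, name = self.samples[idx]
        img = Image.open(filepath).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.usage == 'train':
            return img, target, name  # the train loop also consumes the filename
        return img, target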