Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
if __name__ == "__main__":
    # Run hp.epoch epochs, numbering them from START_EPOCH onward.
    final_epoch = START_EPOCH + hp.epoch
    for epoch in range(START_EPOCH, final_epoch):
        # adjust_learning_rate is called before each epoch's training pass
        adjust_learning_rate(optimizer, epoch)
        train(epoch, hp.wrong_save)
        # mining(epoch)  -- disabled in this script
        valid(epoch)
- RuntimeError: CUDA out of memory. Tried to allocate 74.12 MiB (GPU 0; 5.93 GiB total capacity; 4.73 GiB already allocated; 75.06 MiB free; 19.57 MiB cached)
def train(epoch, wrong_save=False):
    """Train ``net`` for one epoch over the patches yielded by ``trainloader``.

    Optionally collects the file names of patches the model predicted
    incorrectly and writes them to a CSV file.

    Args:
        epoch (int): current epoch; printed and used in the CSV file name.
        wrong_save (bool): if True, write ``wrong_data_epoch<epoch>.csv`` under
            ``cf.wrong_path`` with one misclassified patch file name per row.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    wrong_list = []
    for batch_idx, (inputs, targets, filename) in enumerate(trainloader):
        # Convert targets unconditionally so the CPU path gets a tensor too.
        targets = torch.FloatTensor(np.array(targets).astype(float))
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = targets.cuda()
        optimizer.zero_grad()
        # view(-1) instead of squeeze(): a final batch of one stays 1-D, so
        # per-sample indexing below keeps working.
        # NOTE(review): assumes net returns one score per sample — confirm.
        outputs = net(inputs).view(-1)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        total += targets.size(0)

        # Threshold a detached copy so the autograd graph is never modified in
        # place. Predicted positive iff output >= THRESHOLD, i.e. the same rule
        # as valid()'s floor(output + 1 - THRESHOLD).
        # NOTE(review): the original used floor(output + THRESHOLD) here, which
        # disagrees with valid() whenever THRESHOLD != 0.5.
        preds = (outputs.detach() >= THRESHOLD).float()
        correct += int(preds.eq(targets).sum().item())

        if wrong_save:
            mismatched = preds.ne(targets).tolist()
            wrong_list.extend(name for name, miss in zip(filename, mismatched) if miss)

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    if wrong_save:
        csv_path = cf.wrong_path + 'wrong_data_epoch' + str(epoch) + '.csv'
        # newline='' is what the csv module requires to avoid blank rows on Windows.
        with open(csv_path, 'w', encoding='utf-8', newline='') as wrong_csv:
            writer = csv.writer(wrong_csv)
            writer.writerows([name] for name in wrong_list)

    # Store a plain float (not a tensor); guard against an empty loader.
    CUR_TRA_ACC.append(100. * correct / total if total else 0.0)
def valid(epoch):
    """Evaluate ``net`` on ``valloader`` and checkpoint it when AUC improves.

    Side effects: sets the module global ``THRESHOLD`` to the threshold
    returned by ``stats``. It appends to ``CUR_EPOCH``, ``CUR_VAL_ACC``,
    ``CUR_LOSS`` and ``CUR_LR``. When the AUC exceeds ``BEST_AUC`` it updates
    ``BEST_AUC`` and saves a checkpoint dict to ``./checkpoint/ckpt.t7``.

    Args:
        epoch (int): current epoch; logged and stored in the checkpoint.
    """
    # Only these globals are reassigned here; LR_DECAY / LR_CHANCE are only read.
    global BEST_AUC, THRESHOLD
    import os

    net.eval()
    valid_loss = 0.0
    total = 0
    correct = 0
    num_batches = 0
    output_chunks = []
    target_chunks = []
    # torch.no_grad() replaces Variable(..., volatile=True), which has had no
    # effect since PyTorch 0.4. Without it, every validation batch records an
    # autograd graph and wastes GPU memory.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(valloader):
            targets = torch.FloatTensor(np.array(targets).astype(float))
            if USE_CUDA:
                inputs = inputs.cuda()
                targets = targets.cuda()
            # view(-1) rather than squeeze(): a batch of one stays 1-D.
            # NOTE(review): assumes one score per sample — confirm.
            outputs = net(inputs).view(-1)
            loss = criterion(outputs, targets)
            valid_loss += loss.item()
            total += targets.size(0)
            num_batches = batch_idx + 1

            # Collect per-batch arrays and concatenate once after the loop
            # (np.append inside the loop re-copies everything every batch).
            output_chunks.append(outputs.cpu().numpy().astype(float))
            target_chunks.append(targets.cpu().numpy().astype(float))

            # Predicted positive iff output >= THRESHOLD (the original
            # floor(output + 1 - THRESHOLD) rule). Only feeds the progress bar;
            # the reported accuracy comes from stats() below.
            preds = (outputs >= THRESHOLD).float()
            correct += int(preds.eq(targets).sum().item())
            progress_bar(batch_idx, len(valloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (valid_loss / num_batches, 100. * correct / total, correct, total))

    outputs_list = np.concatenate(output_chunks) if output_chunks else np.array([])
    targets_list = np.concatenate(target_chunks) if target_chunks else np.array([])

    # stats() supplies the confusion-matrix counts and metrics; its `correct`
    # replaces the running count above.
    (correct, tp, tn, fp, fn, recall, precision, specificity,
     f1_score, auc, threshold) = stats(outputs_list, targets_list)
    acc = correct / total
    THRESHOLD = threshold
    print(tp, tn, fp, fn)
    print('Acc: %.3f, Recall: %.3f, Prec: %.3f, Spec: %.3f, F1: %.3f, Thres: %.3f, AUC: %.3f'
          % (acc, recall, precision, specificity, f1_score, threshold, auc))
    # Confusion-matrix table (the '\n' escapes were missing).
    print('%17s %12s\n%-11s %-8d %-8d\n%-11s %-8d %-8d'
          % ('Tumor', 'Normal', 'pos', tp, fp, 'neg', fn, tn))
    current_lr = args.lr * (0.5 ** LR_DECAY)
    print("lr: ", current_lr, "lr chance:", LR_CHANCE)

    # plot data
    CUR_EPOCH.append(epoch)
    CUR_VAL_ACC.append(acc)
    CUR_LOSS.append(valid_loss / max(num_batches, 1))
    CUR_LR.append(current_lr)

    # Save checkpoint.
    if auc > BEST_AUC:
        print('saving...')
        BEST_AUC = auc
        state = {
            # The original `net if USE_CUDA else net` picked `net` either way.
            'net': net,
            'acc': acc,
            'loss': valid_loss,
            'recall': recall,
            'specificity': specificity,
            'precision': precision,
            'f1_score': f1_score,
            'auc': auc,
            'epoch': epoch,
            'lr': current_lr,
            'threshold': threshold,
        }
        # torch.save does not create missing directories.
        os.makedirs('./checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.t7')
def get_dataset(train_transform, test_transform, train_max,
                val_max, subtest_max, ratio=0, mining_mode=False):
    """Build the train, validation, subtest and test datasets (and optionally mining).

    Args:
        train_transform (torchvision.transforms): transform for the train set
            (and the mining set), e.g. for data augmentation.
        test_transform (torchvision.transforms): transform for the validation,
            subtest and test sets.
        train_max (int): ``limit`` passed to the train set.
        val_max (int): ``limit`` passed to the validation set.
        subtest_max (int): ``limit`` passed to the subtest set.
        ratio (int): used only when ``mining_mode`` is true. It is passed as
            ``train_ratio`` to the mining set: the inclusion ratio of the train
            set relative to the mining set.
        mining_mode (bool): if true, also build and return the mining set.

    Returns:
        tuple: ``(train, val, subtest, test)`` datasets, with the mining dataset
        appended as a fifth element when ``mining_mode`` is true.
    """
    train_dataset = camel(cf.dataset_path + 'train/', usage='train',
                          limit=train_max, transform=train_transform)
    val_dataset = camel(cf.dataset_path + 'validation/', usage='val',
                        limit=val_max, transform=test_transform)
    subtest_dataset = camel(cf.dataset_path + 'test/', usage='subtest',
                            limit=subtest_max, transform=test_transform)
    test_dataset = camel(cf.test_path, usage='test', transform=test_transform)
    datasets = (train_dataset, val_dataset, subtest_dataset, test_dataset)

    if not mining_mode:
        return datasets

    mining_dataset = camel(cf.dataset_path + 'mining/', usage='mining',
                           train_ratio=ratio, transform=train_transform)
    return datasets + (mining_dataset,)
Add Comment
Please, Sign In to add comment