Advertisement
Guest User

Untitled

a guest
Feb 27th, 2017
79
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.64 KB | None | 0 0
  1. def ss_training(epoch):
  2. model.train()
  3.  
  4. global ss_train_loader
  5. global ss_unlabeled_loader
  6. global ss_train
  7. global ss_unlabeled
  8.  
  9. #Alternate training on real data and data we're not sure about
  10. if epoch % 2 == 0:
  11. epoch_loader = ss_train_loader
  12. else:
  13. epoch_loader = train_labeled_loader
  14.  
  15. for batch_idx, (data, target) in enumerate(epoch_loader):
  16.  
  17. data, target = Variable(data), Variable(target)
  18. optimizer.zero_grad()
  19. output = model(data)
  20. loss = F.nll_loss(output, target)
  21. loss.backward()
  22. optimizer.step()
  23. if batch_idx % 10 == 0:
  24. print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
  25. epoch, batch_idx * len(data), len(epoch_loader.dataset),
  26. 100. * batch_idx / len(epoch_loader), loss.data[0]))
  27.  
  28. #Evaluate on unlabeled
  29. model.eval()
  30.  
  31. confident = torch.ByteTensor()
  32.  
  33. for data, target in ss_unlabeled_loader:
  34.  
  35. #Run the model on the unlabeled data
  36. data, target = Variable(data, volatile=True), Variable(target)
  37. output=model(data)
  38.  
  39. #Create a list indicating the labels that we are 95% confident about
  40. batch_results = torch.stack((output.data.max(1)[0] >= -.05, output.data.max(1)[1].byte()),1)
  41. confident = torch.cat((confident,batch_results), 0)
  42.  
  43. if epoch > 25:
  44.  
  45. #Create new data sets based on whether we are confident or not
  46. new_unlabeled_data = torch.ByteTensor()
  47.  
  48. for i in range(len(ss_unlabeled.train_data)):
  49. if torch.equal(confident[i][0], torch.ByteTensor([1])):
  50. ss_train.train_data = torch.cat((ss_train.train_data,ss_unlabeled.train_data[i].unsqueeze(0)))
  51. ss_train.train_labels = torch.cat((ss_train.train_labels, confident[i][1].long()))
  52. ss_train.k += 1
  53. else:
  54. new_unlabeled_data = torch.cat((new_unlabeled_data,ss_unlabeled.train_data[i].unsqueeze(0)))
  55.  
  56. if i % 500 == 0:
  57. print("Finished assessing %s/%s of unlabeled data" % (i, len(ss_unlabeled.train_data)))
  58.  
  59. #Make the new unlabeled data
  60. ss_unlabeled.train_data = new_unlabeled_data
  61. ss_unlabeled.train_labels = torch.LongTensor(-1*np.ones(len(new_unlabeled_data)).astype(int))
  62. ss_unlabeled.k = len(new_unlabeled_data)
  63.  
  64. #Reload the data
  65. ss_train_loader = torch.utils.data.DataLoader(ss_train, batch_size=64, shuffle=True)
  66. ss_unlabeled_loader = torch.utils.data.DataLoader(ss_unlabeled, batch_size=64, shuffle=False)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement