Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def ss_training(epoch):
    """Run one epoch of semi-supervised training.

    Alternates each epoch between the growing pseudo-labeled training set
    (``ss_train_loader``) and the original labeled set
    (``train_labeled_loader``).  After the training pass, the model is run
    over the remaining unlabeled pool to score each sample; from epoch 26
    onward, samples whose prediction clears the confidence threshold are
    moved (with their predicted label) into the training set, the unlabeled
    pool is shrunk accordingly, and both DataLoaders are rebuilt.

    Relies on module-level ``model``, ``optimizer``, ``train_labeled_loader``
    and the ``ss_*`` globals declared below.  Written for pre-0.4 PyTorch
    (``Variable``, ``volatile=True``, ``loss.data[0]``).

    Args:
        epoch: epoch counter; its parity selects which loader to train on,
            and pseudo-labeling only starts once ``epoch > 25``.
    """
    model.train()
    # Reassigned/mutated below when confident samples are promoted out of
    # the unlabeled pool.
    global ss_train_loader
    global ss_unlabeled_loader
    global ss_train
    global ss_unlabeled
    # Alternate training on real data and data we're not sure about.
    if epoch % 2 == 0:
        epoch_loader = ss_train_loader
    else:
        # NOTE(review): train_labeled_loader is not declared global and is
        # not defined in this block — presumably the loader over the
        # original labeled data; confirm it exists at module level.
        epoch_loader = train_labeled_loader
    for batch_idx, (data, target) in enumerate(epoch_loader):
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        # nll_loss expects log-probabilities, so the model presumably ends
        # in log-softmax — confirm against the model definition.
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            # loss.data[0] is pre-0.4 PyTorch; on >=0.4 this must become
            # loss.item().
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(epoch_loader.dataset),
                100. * batch_idx / len(epoch_loader), loss.data[0]))
    # Evaluate on unlabeled data.
    model.eval()
    confident = torch.ByteTensor()
    for data, target in ss_unlabeled_loader:
        # Run the model on the unlabeled data (volatile=True is the
        # pre-0.4 equivalent of torch.no_grad()).
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        # Column 0: 1 iff the top log-probability clears -0.05, i.e.
        # predicted probability >= exp(-0.05) ~= 0.95 ("95% confident").
        # Column 1: the predicted class index, stored as a byte.
        batch_results = torch.stack((output.data.max(1)[0] >= -.05, output.data.max(1)[1].byte()), 1)
        confident = torch.cat((confident, batch_results), 0)
    if epoch > 25:
        # Create new data sets based on whether we are confident or not.
        new_unlabeled_data = torch.ByteTensor()
        for i in range(len(ss_unlabeled.train_data)):
            # NOTE(review): on PyTorch >=0.4 confident[i][0] is a 0-dim
            # tensor and torch.equal against a shape-(1,) tensor is always
            # False — this comparison only behaves as intended on the old
            # torch this file targets; confirm the installed version.
            if torch.equal(confident[i][0], torch.ByteTensor([1])):
                # Promote the sample into the training set with its
                # predicted label.
                ss_train.train_data = torch.cat((ss_train.train_data, ss_unlabeled.train_data[i].unsqueeze(0)))
                ss_train.train_labels = torch.cat((ss_train.train_labels, confident[i][1].long()))
                ss_train.k += 1
            else:
                new_unlabeled_data = torch.cat((new_unlabeled_data, ss_unlabeled.train_data[i].unsqueeze(0)))
            if i % 500 == 0:
                print("Finished assessing %s/%s of unlabeled data" % (i, len(ss_unlabeled.train_data)))
        # Make the new unlabeled data; -1 acts as the "no label" marker.
        ss_unlabeled.train_data = new_unlabeled_data
        ss_unlabeled.train_labels = torch.LongTensor(-1*np.ones(len(new_unlabeled_data)).astype(int))
        ss_unlabeled.k = len(new_unlabeled_data)
        # Reload the data.
        ss_train_loader = torch.utils.data.DataLoader(ss_train, batch_size=64, shuffle=True)
        ss_unlabeled_loader = torch.utils.data.DataLoader(ss_unlabeled, batch_size=64, shuffle=False)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement