Advertisement
Guest User

Untitled

a guest
Jun 18th, 2019
75
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.42 KB | None | 0 0
from sklearn.metrics import accuracy_score

import torch
from torch import nn
from torch.optim import SGD
from torchnet.logger import VisdomPlotLogger, VisdomSaver
from torchnet.meter import AverageValueMeter
  5.  
  6.  
  7. def train_classifier(model, train_loader, test_loader, exp_name='experiment', lr=0.01, epochs=10, momentum=0.99):
  8. criterion = nn.CrossEntropyLoss()
  9. optimizer = SGD(model.parameters(), lr, momentum=momentum)
  10. #meters
  11. loss_meter = AverageValueMeter()
  12. acc_meter = AverageValueMeter()
  13. #plotters
  14. loss_logger = VisdomPlotLogger('line', env=exp_name, opts={'title': 'Loss', 'legend':['train','test']})
  15. acc_logger = VisdomPlotLogger('line', env=exp_name, opts={'title': 'Accuracy','legend':['train','test']})
  16. visdom_saver = VisdomSaver(envs=[exp_name])
  17. #device
  18. device = "cuda" if torch.cuda.is_available() else "cpu"
  19. model.to(device)
  20. #definiamo un dizionario contenente i loader di training e test
  21. loader = {
  22. 'train' : train_loader,
  23. 'test' : test_loader
  24. }
  25. for e in range(epochs):
  26. #iteriamo tra due modalità: train e test
  27. for mode in ['train','test'] :
  28. loss_meter.reset(); acc_meter.reset()
  29. model.train() if mode == 'train' else model.eval()
  30. with torch.set_grad_enabled(mode=='train'): #abilitiamo i gradienti solo in training
  31. for i, batch in enumerate(loader[mode]):
  32. x=batch[0].to(device) #"portiamoli sul device corretto"
  33. y=batch[1].to(device)
  34. output = model(x)
  35. l = criterion(output,y)
  36. if mode=='train':
  37. l.backward()
  38. optimizer.step()
  39. optimizer.zero_grad()
  40. acc = accuracy_score(y.to('cpu'),output.to('cpu').max(1)[1])
  41. n = batch[0].shape[0] #numero di elementi nel batch
  42. loss_meter.add(l.item()*n,n)
  43. acc_meter.add(acc*n,n)
  44. if mode=='train':
  45. loss_logger.log(e+(i+1)/len(loader[mode]), loss_meter.value()[0], name=mode)
  46. acc_logger.log(e+(i+1)/len(loader[mode]), acc_meter.value()[0], name=mode)
  47. loss_logger.log(e+1, loss_meter.value()[0], name=mode)
  48. acc_logger.log(e+1, acc_meter.value()[0], name=mode)
  49. torch.save(model.state_dict(),'%s-%d.pth'%(exp_name,e+1))
  50. return model
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement