import torch

# Assumes a CUDA device when available; falls back to CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train_epoch(model, optimizer, train_loader, criterion):
    # One full optimization pass over the training set.
    model.train()
    for X, y in train_loader:
        out = model(X.to(device))
        loss = criterion(out, y.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return


def evaluate_loss(loader, model, criterion):
    # Mean batch loss over the whole loader, without tracking gradients.
    with torch.no_grad():
        model.eval()
        loss = 0.0
        k = 0
        for X, y in loader:
            k += 1
            out = model(X.to(device))
            loss += criterion(out, y.to(device)).item()
        loss /= k
        return loss


def train(model, opt, train_loader, test_loader, criterion, n_epochs,
          writer=None, verbose=True, save=False, scheduler=None,
          experiment_path='./'):
    for epoch in range(n_epochs):
        train_epoch(model, opt, train_loader, criterion)
        train_loss = evaluate_loss(train_loader, model, criterion)
        val_loss = evaluate_loss(test_loader, model, criterion)
        if writer:
            # Log both curves under one tag (e.g. a TensorBoard SummaryWriter).
            writer.add_scalars('loss',
                               {'validation': val_loss, 'train': train_loss},
                               global_step=epoch)
        if save:
            # Checkpoint the weights into the experiment directory.
            torch.save(model.state_dict(), experiment_path + 'model')
        if verbose:
            print('Epoch [%d/%d], Loss (train/test): %.6f/%.6f'
                  % (epoch + 1, n_epochs, train_loss, val_loss))
        if scheduler:
            scheduler.step()
    return
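

# Minimal usage sketch. Assumption: the helpers above are model-agnostic, so a
# toy linear-regression setup with MSE loss stands in for whatever model,
# loaders, and criterion the author actually used; none of these names come
# from the paste itself.
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

X = torch.randn(256, 10)
y = torch.randn(256, 1)
train_loader = DataLoader(TensorDataset(X[:200], y[:200]), batch_size=32)
test_loader = DataLoader(TensorDataset(X[200:], y[200:]), batch_size=32)

model = nn.Linear(10, 1).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

train(model, opt, train_loader, test_loader, criterion, n_epochs=5)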