Untitled
a guest, Nov 30th, 2020
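The training loop below calls an autoencoder class that is not included in the paste. Purely as a hypothetical sketch of the interface the loop assumes (a config-driven constructor, Conv1d-style input shaped (batch, 1, n_features), and a flattened output matching the targets), one compatible stand-in could be:

# Hypothetical stand-in for the `autoencoder` class the paste assumes;
# only the interface matters here: a config-driven constructor, input of
# shape (batch, 1, n_features), flattened (batch, n_features) output.
import torch

class autoencoder(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        channels = config.get("channels", 8)  # assumed config key
        self.net = torch.nn.Sequential(
            torch.nn.Conv1d(1, channels, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(channels, 1, kernel_size=3, padding=1),
            torch.nn.Flatten(),  # (batch, 1, L) -> (batch, L) to match the targets
        )

    def forward(self, x):
        return self.net(x)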
import os

import numpy as np
import torch
from ray import tune


def train(config, checkpoint_dir=None, data=None):
    # data = (X_2, original)
    loss_fn = torch.nn.MSELoss()
    model = autoencoder(config)
    optimizer = torch.optim.SGD(model.parameters(), lr=config['lr'], momentum=0.9)
    maxIter = 20000
    batchAmount = config['batchSize']

    # Restore model and optimizer state when Tune resumes from a checkpoint.
    if checkpoint_dir:
        checkpoint = os.path.join(checkpoint_dir, "checkpoint")
        model_state, optimizer_state = torch.load(checkpoint)
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    for t in range(maxIter):
        epoch_loss = 0

        optimizer.zero_grad()
        # Bootstrap a random subset of the total samples for this step.
        idx = np.random.randint(data[0].shape[0], size=batchAmount)

        # Add a channel dimension so the batch is shaped for convolution:
        # (batchAmount, 1, n_features).
        X_scaled = torch.unsqueeze(torch.from_numpy(data[0][idx, :]).float(), dim=1)

        # Flatten the matching target samples to (batchAmount, n_features).
        testValues = torch.from_numpy(
            np.reshape(data[1][idx, :],
                       (batchAmount, -1))
        ).float()

        y_pred = model(X_scaled)  # predict on the subset

        # MSELoss expects (input, target): the prediction comes first.
        loss = loss_fn(y_pred, testValues)
        epoch_loss += loss.item()

        # Every tenth of the run, report the score and write a checkpoint.
        if t != 0 and t % (maxIter // 10) == 0:
            tune.report(score=epoch_loss)
            with tune.checkpoint_dir(step=t) as checkpoint_dir:
                path = os.path.join(checkpoint_dir, "checkpoint")
                torch.save(
                    (model.state_dict(), optimizer.state_dict()), path)

        loss.backward()   # backpropagate the loss
        optimizer.step()  # apply the gradient update

    tune.report(score=epoch_loss)
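As written, train is a Ray Tune trainable. A minimal launch sketch follows; the X_2 and original arrays, the search space, and num_samples are illustrative assumptions, not part of the paste:

# Hypothetical driver: X_2 / original stand in for the arrays the trainable
# expects in `data`; the search space below is illustrative only.
from ray import tune

analysis = tune.run(
    tune.with_parameters(train, data=(X_2, original)),
    config={
        "lr": tune.loguniform(1e-4, 1e-1),        # assumed range
        "batchSize": tune.choice([32, 64, 128]),  # assumed values
        # ...plus whatever keys the autoencoder constructor reads
    },
    num_samples=10,
)
print(analysis.get_best_config(metric="score", mode="min"))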