import os

import numpy as np
import torch
from ray import tune


def train(config, checkpoint_dir=None, data=None):
    # data = (X_2, original): the scaled inputs and their reconstruction targets
    loss_fn = torch.nn.MSELoss()
    model = autoencoder(config)  # user-defined autoencoder model class
    optimizer = torch.optim.SGD(model.parameters(), lr=config['lr'], momentum=0.9)
    maxIter = 20000
    batchAmount = config['batchSize']

    # Restore model and optimizer state when Ray Tune resumes from a checkpoint.
    if checkpoint_dir:
        checkpoint = os.path.join(checkpoint_dir, "checkpoint")
        model_state, optimizer_state = torch.load(checkpoint)
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    for t in range(maxIter):
        optimizer.zero_grad()
        # Bootstrap a random subset of the total samples for this step.
        idx = np.random.randint(data[0].shape[0], size=batchAmount)
        # Add a channel dimension so the batch is shaped for convolution.
        X_scaled = torch.unsqueeze(torch.from_numpy(data[0][idx, :]).float(), dim=1)
        # Flatten the matching targets to compare against the model output.
        testValues = torch.from_numpy(
            np.reshape(data[1][idx, :], (batchAmount, -1))
        ).float()

        y_pred = model(X_scaled)            # predict on the subset
        loss = loss_fn(y_pred, testValues)  # MSE between prediction and target
        step_loss = loss.item()

        # Every tenth of the run, report the score and write a checkpoint.
        if t != 0 and t % (maxIter // 10) == 0:
            tune.report(score=step_loss)
            with tune.checkpoint_dir(step=t) as checkpoint_dir:
                path = os.path.join(checkpoint_dir, "checkpoint")
                torch.save((model.state_dict(), optimizer.state_dict()), path)

        loss.backward()   # backpropagate the loss
        optimizer.step()  # update the weights

    tune.report(score=step_loss)
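
# A minimal sketch of how this trainable might be launched with Ray Tune's
# older functional API (which the checkpoint_dir signature above targets).
# The search-space values, num_samples, and the X_2/original arrays passed
# through tune.with_parameters are illustrative assumptions, not part of
# the original paste.
if __name__ == "__main__":
    analysis = tune.run(
        tune.with_parameters(train, data=(X_2, original)),  # X_2/original: hypothetical arrays
        config={
            "lr": tune.loguniform(1e-4, 1e-1),       # assumed learning-rate range
            "batchSize": tune.choice([32, 64, 128]), # assumed batch sizes
            # ... plus whatever keys autoencoder(config) expects
        },
        num_samples=10,
        metric="score",
        mode="min",
    )
    print("Best config:", analysis.best_config)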