Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # NOTE(review): dead code — this entire training loop is pasted as
- # commented-out diff lines ("- # " prefix on every line). Preserved
- # verbatim below; nothing here executes.
- # From the text itself, the loop appears to: for each batch `x`, pick one
- # random position per row, record the value at that position as the label,
- # overwrite the row's tail (from that position on) with `pad_index`, and
- # train `net` to predict the recorded value (next-token-style objective).
- # NOTE(review): `print_every = len(dataloader) // print_count` is 0 when
- # the loader has fewer than 10 batches, so the later
- # `iteration % print_every` would raise ZeroDivisionError — guard with
- # max(1, ...) before reviving this code.
- # NOTE(review): depends on `tqdm_notebook`, `np` (numpy), `torch`, and a
- # CUDA device; none of those imports are visible in this paste — confirm
- # against the original file.
- # def train_epoch(net, optimizer, criterion, dataloader, pad_index):
- #     print_count = 10
- #     print_every = len(dataloader) // print_count
- #     running_loss = 0.0
- #     batches_ran = 0
- #     # presumably x is a (batch, seq) LongTensor of token ids — TODO confirm
- #     for iteration, x in tqdm_notebook(enumerate(dataloader)):
- #         batches_ran += 1
- #         optimizer.zero_grad()
- #         # one random cut position per row, in [0, seq_len - 2]
- #         # (np.random.randint's upper bound is exclusive)
- #         splits = np.random.randint(0, x.shape[1]-1, size=x.shape[0])
- #         x = x.cuda()
- #         labels = torch.empty(x.shape[0],dtype=torch.long).cuda()
- #         for i,spl in enumerate(splits):
- #             # capture the target value BEFORE masking it out below
- #             idx = x[i,spl].item()
- #             # mask the tail in place; note this mutates the batch tensor
- #             x[i,spl:] = pad_index
- #             labels[i] = idx
- #         output, h = net(x)
- #         loss = criterion(output, labels)
- #         loss.backward()
- #         running_loss += loss.item()
- #         optimizer.step()
- #         # report the running average every `print_every` batches,
- #         # then reset the accumulators
- #         if iteration % print_every == print_every - 1:
- #             print("iteration {} loss {}".format(iteration, running_loss / batches_ran))
- #             running_loss = 0.0
- #             batches_ran = 0
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement