Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
def test_model(model):
    """Evaluate exact-match span accuracy of ``model`` on the dev data.

    Iterates over the module-level ``dev_data_loader``, runs ``model`` on
    ``device`` under ``torch.no_grad()``, and decodes predicted spans with
    ``get_best``. An example counts as correct only when both the predicted
    start AND end positions equal the gold ones.

    Args:
        model: Called as ``model(texts, mask=..., start_positions=...,
            end_positions=...)``; must return a 2-tuple whose second element
            is passed to ``get_best``.

    Returns:
        The accuracy as a Python float (also printed), or ``None`` when the
        loader yields no examples.
    """
    # TODO: load test data properly (currently evaluates on the dev loader).
    model.eval()  # switch to evaluation mode; the model is left in eval mode
    total = 0
    correct = 0
    with torch.no_grad():
        for texts, masks, start_pos, end_pos in dev_data_loader:
            # as_tensor avoids the copy + warning torch.tensor() emits when the
            # collated batch is already a tensor; build once, reuse below.
            start_t = torch.as_tensor(start_pos, device=device)
            end_t = torch.as_tensor(end_pos, device=device)
            _, probs = model(texts.to(device),
                             mask=masks.to(device),
                             start_positions=start_t,
                             end_positions=end_t)
            start, end = get_best(probs)
            # NOTE(review): get_best's return type/device is not visible here;
            # align predictions with the gold tensors before comparing.
            start = torch.as_tensor(start, device=start_t.device)
            end = torch.as_tensor(end, device=end_t.device)
            # Previously compared against undefined names start_positions /
            # end_positions (NameError); use the gold tensors of this batch.
            correct += ((start == start_t) & (end == end_t)).sum().item()
            # assumes one gold start/end position per example — TODO confirm
            total += start_t.numel()
    if total == 0:
        # Avoid ZeroDivisionError on an empty loader.
        print('Accuracy on dev data is undefined (no examples)')
        return None
    accuracy = correct / total
    print(f'Accuracy on dev data is {accuracy}')
    return accuracy
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement