Advertisement
Guest User

Untitled

a guest
Dec 9th, 2019
99
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 0.63 KB | None | 0 0
  1. def test_model(model):
  2.   #TODO load test properly
  3.   model.eval()
  4.   total = 0
  5.   correct = 0
  6.   with torch.no_grad():
  7.     for i, (texts, masks, start_pos, end_pos) in enumerate(dev_data_loader):
  8.         _, probs = model(texts.to(device),
  9.                         mask=masks.to(device),
  10.                         start_positions=torch.tensor(start_pos).to(device),
  11.                         end_positions=torch.tensor(end_pos).to(device))
  12.         start, end = get_best(probs)
  13.         correct += torch.sum((start == start_positions) * (end == end_positions))
  14.         total += len(start)
  15.   print(f'Accuracy on dev data is {correct / total}')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement