import torch
import torch.nn as nn


class RNN(nn.Module):

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim, weights):
        super().__init__()

        # Embedding layer initialised from a pre-trained weight matrix of
        # shape [input_dim, embedding_dim]; `weights` must be passed in,
        # since it is not defined anywhere else in this paste.
        self.embedding = nn.Embedding.from_pretrained(weights)
        self.rnn = nn.RNN(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):

        # x: [seq_len, batch_size] -> embedded: [seq_len, batch_size, embedding_dim]
        embedded = self.embedding(x)

        # output: [seq_len, batch_size, hidden_dim], hidden: [1, batch_size, hidden_dim]
        output, hidden = self.rnn(embedded)

        # For a single-layer, unidirectional RNN the last time step of
        # `output` is identical to the final hidden state.
        assert torch.equal(output[-1, :, :], hidden.squeeze(0))

        return self.fc(hidden.squeeze(0))
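The training and evaluation loops below call a binary_accuracy helper that is not included in the paste. A minimal sketch of what it could look like, assuming the model returns raw logits and the labels are 0/1 floats:

def binary_accuracy(preds, y):
    # Sigmoid-activate the logits, round to 0/1, and compare with the labels.
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()
    return correct.sum() / len(correct)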
def train(model, iterator, optimizer, criterion):

    epoch_loss = 0
    epoch_acc = 0

    # Training mode: enables dropout, batch-norm updates, etc.
    model.train()

    for batch in iterator:

        # Clear gradients accumulated from the previous batch.
        optimizer.zero_grad()

        # predictions: [batch_size] after removing the trailing output dimension.
        predictions = model(batch["text"]).squeeze(1)

        loss = criterion(predictions, batch["label"])
        acc = binary_accuracy(predictions, batch["label"])

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)
def evaluate(model, iterator, criterion):

    epoch_loss = 0
    epoch_acc = 0

    # Evaluation mode: disables dropout and freezes batch-norm statistics.
    model.eval()

    # No gradients are needed for evaluation, which saves memory and time.
    with torch.no_grad():

        for batch in iterator:

            predictions = model(batch["text"]).squeeze(1)

            loss = criterion(predictions, batch["label"])
            acc = binary_accuracy(predictions, batch["label"])

            epoch_loss += loss.item()
            epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)
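The paste does not show how the model, loss, and loops are wired together. The following is a hypothetical usage sketch: the vocabulary size, hyperparameters, the random stand-in for the pre-trained weight matrix, and the dummy batch iterator are all assumptions for illustration, not part of the original. Batches are dicts with "text" and "label" keys, as train() and evaluate() index them.

import torch.optim as optim

VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM = 1000, 100, 256, 1
SEQ_LEN, BATCH_SIZE = 20, 32

# Hypothetical pre-trained embeddings; in practice these would come from
# e.g. GloVe vectors aligned to the vocabulary.
pretrained_weights = torch.randn(VOCAB_SIZE, EMBEDDING_DIM)

model = RNN(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, pretrained_weights)
optimizer = optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

# Dummy iterator shaped the way train()/evaluate() expect:
# "text" is [seq_len, batch_size] token indices, "label" is [batch_size] floats.
dummy_iterator = [
    {"text": torch.randint(0, VOCAB_SIZE, (SEQ_LEN, BATCH_SIZE)),
     "label": torch.randint(0, 2, (BATCH_SIZE,)).float()}
    for _ in range(10)
]

for epoch in range(3):
    train_loss, train_acc = train(model, dummy_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, dummy_iterator, criterion)
    print(f"Epoch {epoch + 1}: train loss {train_loss:.3f}, acc {train_acc:.3f} | "
          f"valid loss {valid_loss:.3f}, acc {valid_acc:.3f}")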