Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class RNN(nn.Module):
    """Sequence classifier: pretrained embeddings -> single-layer RNN -> linear head.

    Args:
        input_dim: Accepted for interface compatibility but not used; the
            vocabulary size comes from the module-level ``weights`` tensor.
        embedding_dim: Input size of the RNN; must equal ``weights.shape[1]``
            or the first forward pass fails.
        hidden_dim: Size of the RNN hidden state.
        output_dim: Number of outputs produced by the linear head.

    NOTE(review): ``weights`` is a free name looked up in module scope when
    the model is constructed -- it is not a parameter, so it must exist
    before ``RNN(...)`` is called.
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        # Keep this registration order: it fixes the parameter order and the
        # sequence of RNG draws used to initialise the RNN and linear layers.
        # `from_pretrained` freezes the table by default (freeze=True).
        self.embedding = nn.Embedding.from_pretrained(weights)
        self.rnn = nn.RNN(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply the linear head to the RNN's final hidden state.

        ``nn.RNN`` uses the default ``batch_first=False``, so ``x`` is
        presumably token ids shaped (seq_len, batch) -- confirm against the
        data pipeline; the result is then shaped (batch, output_dim).
        """
        seq_out, h_n = self.rnn(self.embedding(x))
        final_state = h_n.squeeze(0)  # drop the single layer/direction axis
        # Sanity check: for a single-layer, unidirectional RNN the last time
        # step of the output sequence equals the final hidden state.
        # Asserts are stripped when Python runs with -O.
        assert torch.equal(seq_out[-1, :, :], final_state)
        return self.fc(final_state)
def train(model, iterator, optimizer, criterion):
    """Run one training epoch and return the batch-averaged loss and accuracy.

    Each batch is indexed with ``"text"`` (model input) and ``"label"``
    (targets). The model output is squeezed on dim 1 before scoring, so it
    is presumably shaped (batch, 1) -- confirm against the model.

    Args:
        model: Module to train; switched to training mode first.
        iterator: Any iterable of batches. Batches are counted as they are
            consumed rather than via ``len()``, so generators and loaders
            without ``__len__`` work too.
        optimizer: Zeroed, then stepped once per batch after backprop.
        criterion: Loss called as ``criterion(predictions, labels)``.

    Returns:
        ``(mean_loss, mean_acc)`` -- per-batch means, so every batch
        (including a smaller final one) carries equal weight.

    Raises:
        ValueError: If ``iterator`` yields no batches (averages undefined).
    """
    epoch_loss = 0.0
    epoch_acc = 0.0
    num_batches = 0
    model.train()
    for batch in iterator:
        labels = batch["label"]
        optimizer.zero_grad()
        predictions = model(batch["text"]).squeeze(1)
        loss = criterion(predictions, labels)
        # `binary_accuracy` is defined elsewhere in the file; its result must
        # support `.item()`.
        acc = binary_accuracy(predictions, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("train: iterator yielded no batches")
    return epoch_loss / num_batches, epoch_acc / num_batches
def evaluate(model, iterator, criterion):
    """Score ``model`` on every batch without updating it.

    Each batch is indexed with ``"text"`` (model input) and ``"label"``
    (targets). The model output is squeezed on dim 1 before scoring, so it
    is presumably shaped (batch, 1) -- confirm against the model.

    Args:
        model: Module to evaluate; switched to eval mode first.
        iterator: Any iterable of batches. Batches are counted as they are
            consumed rather than via ``len()``, so generators and loaders
            without ``__len__`` work too.
        criterion: Loss called as ``criterion(predictions, labels)``.

    Returns:
        ``(mean_loss, mean_acc)`` -- per-batch means, so every batch
        (including a smaller final one) carries equal weight.

    Raises:
        ValueError: If ``iterator`` yields no batches (averages undefined).
    """
    epoch_loss = 0.0
    epoch_acc = 0.0
    num_batches = 0
    model.eval()
    # No autograd graph is needed when only scoring.
    with torch.no_grad():
        for batch in iterator:
            labels = batch["label"]
            predictions = model(batch["text"]).squeeze(1)
            loss = criterion(predictions, labels)
            # `binary_accuracy` is defined elsewhere in the file; its result
            # must support `.item()`.
            acc = binary_accuracy(predictions, labels)
            epoch_loss += loss.item()
            epoch_acc += acc.item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError("evaluate: iterator yielded no batches")
    return epoch_loss / num_batches, epoch_acc / num_batches
Add Comment
Please, Sign In to add comment