import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class _EncoderModule(nn.Module):
    def __init__(self, embeddings_table, embeddings_size, hidden_size):
        super().__init__()
        self.embeddings_table = embeddings_table
        self.embeddings_size = embeddings_size
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size=embeddings_size, hidden_size=hidden_size, bidirectional=True, batch_first=True)

    def forward(self, src_sentences, src_lengths):
        # src_sentences: (batch, max_src_len); src_lengths must be sorted in decreasing order for pack_padded_sequence
        src_embed = self.embeddings_table(src_sentences)
        packed_embed = pack_padded_sequence(src_embed, src_lengths, batch_first=True)
        packed_all_hidden, (hidden_states, cell_states) = self.lstm(packed_embed)
        all_hidden_states, _ = pad_packed_sequence(packed_all_hidden, batch_first=True)
        return all_hidden_states, hidden_states, cell_states

class _DecoderModule(nn.Module):
    def __init__(
        self,
        embeddings_table,
        embeddings_size,
        hidden_size,
        start_idx,
        dropout_prob,
        dst_vocab_size,
        teacher_forcing
    ):
        super().__init__()
        self.embeddings_table = embeddings_table
        self.embeddings_size = embeddings_size
        self.hidden_size = hidden_size
        self.start_idx = torch.tensor(start_idx).to(DEVICE)
        # projections of the concatenated bidirectional encoder states down to the decoder hidden size
        self.W_h = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
        self.W_c = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
        self.W_attn = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
        self.W_u = nn.Linear(3 * self.hidden_size, self.hidden_size, bias=False)
        self.W_vocab = nn.Linear(self.hidden_size, dst_vocab_size, bias=False)
        self.lstm_cell = nn.LSTMCell(self.hidden_size + self.embeddings_size, self.hidden_size, bias=True)
        self.dropout = nn.Dropout(p=dropout_prob)
        self.teacher_forcing = teacher_forcing
    def forward(
        self,
        all_enc_hidden_states,
        final_enc_hidden_states,
        final_enc_cell_states,
        max_sentence_length,
        correct_indices=None
    ):
        batch_size = all_enc_hidden_states.shape[0]
        out = []
        # (2, batch, hidden) -> (batch, 2*hidden), then project down to (batch, hidden)
        proj_final_enc_hidden_states = self.W_h(
            final_enc_hidden_states.transpose(0, 1).contiguous().view(batch_size, 2 * self.hidden_size))
        proj_final_enc_cell_states = self.W_c(
            final_enc_cell_states.transpose(0, 1).contiguous().view(batch_size, 2 * self.hidden_size))
        y_t = self.embeddings_table(self.start_idx.repeat(batch_size))
        state = (proj_final_enc_hidden_states, proj_final_enc_cell_states)
        o_t = torch.zeros(batch_size, self.hidden_size, device=DEVICE)
        for i in range(max_sentence_length):
            # input feeding: concatenate the previous embedding with the previous attentional output
            y_barra = torch.cat((y_t, o_t), dim=1)
            state = self.lstm_cell(y_barra, state)
            h_t, _ = state
            # multiplicative attention over all encoder hidden states
            e_t = torch.bmm(self.W_attn(all_enc_hidden_states), h_t.unsqueeze(2))
            alfa_t = torch.softmax(e_t, dim=1)  # (batch, max_src_len, 1)
            a_t = torch.sum(alfa_t * all_enc_hidden_states, dim=1)  # context vector, (batch, 2*hidden)
            u_t = torch.cat((h_t, a_t), dim=1)
            v_t = self.W_u(u_t)
            o_t = self.dropout(torch.tanh(v_t))
            P_t = self.W_vocab(o_t)  # pre-softmax scores over the target vocabulary
            out.append(P_t)
            _, max_indices = P_t.max(dim=1)
            if self.teacher_forcing and self.training:  # 100% teacher forcing
                y_t = self.embeddings_table(correct_indices[:, i])
            else:
                y_t = self.embeddings_table(max_indices)
        return torch.stack(out, dim=1)

class NeuralMachineTranslator(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        start_idx,
        embeddings_size,
        hidden_size,
        dropout_prob,
        teacher_forcing=False
    ):
        super().__init__()
        self.embeddings_table = nn.Embedding(src_vocab_size, embeddings_size, padding_idx=0)
        self.encoder_module = _EncoderModule(
            embeddings_table=self.embeddings_table,
            embeddings_size=embeddings_size,
            hidden_size=hidden_size)
        self.decoder_module = _DecoderModule(
            embeddings_table=self.embeddings_table,
            embeddings_size=embeddings_size,
            hidden_size=hidden_size,
            start_idx=start_idx,
            dropout_prob=dropout_prob,
            dst_vocab_size=src_vocab_size,
            teacher_forcing=teacher_forcing)
    def forward(
        self,
        src_sentences,
        dst_sentences,  # necessary when using teacher_forcing
        src_lengths,
        dst_lengths
    ):
        max_sentence_length = dst_lengths.max()
        # run the encoder
        all_hidden_states, hidden_states, cell_states = self.encoder_module(src_sentences, src_lengths)
        # run the decoder
        pre_logits = self.decoder_module(all_hidden_states, hidden_states, cell_states, max_sentence_length, dst_sentences)
        # note: pre-softmax scores are returned because CrossEntropyLoss applies log-softmax internally
        return pre_logits
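

# Usage sketch (hedged): the hyperparameters, vocabulary size, start-token index, and dummy
# batch below are illustrative assumptions, not values from the original setup. It shows how
# the model is instantiated, how a forward pass is run, and how the returned pre-softmax
# scores feed into nn.CrossEntropyLoss.
if __name__ == "__main__":
    VOCAB_SIZE = 1000  # assumed size of the shared source/target vocabulary

    model = NeuralMachineTranslator(
        src_vocab_size=VOCAB_SIZE,
        start_idx=1,          # assumed index of the start-of-sentence token
        embeddings_size=64,
        hidden_size=128,
        dropout_prob=0.3,
        teacher_forcing=True,
    ).to(DEVICE)

    # dummy batch of 2 sentences; lengths sorted in decreasing order for pack_padded_sequence,
    # kept on CPU as pack_padded_sequence expects
    src_sentences = torch.randint(2, VOCAB_SIZE, (2, 7)).to(DEVICE)
    dst_sentences = torch.randint(2, VOCAB_SIZE, (2, 5)).to(DEVICE)
    src_lengths = torch.tensor([7, 5])
    dst_lengths = torch.tensor([5, 4])

    # pre_logits: (batch, max_dst_len, vocab)
    pre_logits = model(src_sentences, dst_sentences, src_lengths, dst_lengths)

    # CrossEntropyLoss expects (N, C) scores and (N,) targets; padding index 0 is ignored
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    loss = criterion(pre_logits.reshape(-1, VOCAB_SIZE), dst_sentences.reshape(-1))
    loss.backward()
    print(loss.item())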