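# --- Setup (not in the original paste) ---
# The paste assumes these imports and constants are defined elsewhere; the
# values below are illustrative placeholders so the snippet runs standalone.
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

NUM_ENGLISH_WORDS = 10_000  # assumed English vocabulary size
NUM_FRENCH_WORDS = 12_000   # assumed French vocabulary size
PADDING_TOKEN = 0           # assumed index of the padding token
EOS_TOKEN = 1               # assumed index of the end-of-sequence token
DROPOUT_PROB = 0.2          # assumed inter-layer dropout probability
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# test_loader is assumed to yield (X, Y, X_sequence_lengths, Y_sequence_lengths) batches.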
# One embedding table per language; both share the 128-dim model width.
embed_ENG: nn.Embedding = nn.Embedding(num_embeddings=NUM_ENGLISH_WORDS, embedding_dim=128, padding_idx=PADDING_TOKEN)
embed_FR: nn.Embedding = nn.Embedding(num_embeddings=NUM_FRENCH_WORDS, embedding_dim=128, padding_idx=PADDING_TOKEN)
encoder: nn.GRU = nn.GRU(input_size=128, hidden_size=128,
                         batch_first=True, dropout=DROPOUT_PROB, num_layers=2)
decoder: nn.GRU = nn.GRU(input_size=128, hidden_size=128,
                         batch_first=True, dropout=DROPOUT_PROB, num_layers=2)
# Projects each decoder hidden state onto the French vocabulary.
linear: nn.Linear = nn.Linear(in_features=128, out_features=NUM_FRENCH_WORDS)
firstBatch = next(iter(test_loader))
X, Y, X_sequence_lengths, Y_sequence_lengths = firstBatch
# X, Y = X.to(DEVICE), Y.to(DEVICE)  # lengths stay on the CPU for pack_padded_sequence
# pack_padded_sequence with enforce_sorted=True expects the batch sorted by length, longest first.
X_sequence_lengths_sorted, X_sortedIndices = X_sequence_lengths.sort(descending=True)
_, X_unsortedIndices = X_sortedIndices.sort()  # inverse permutation to restore the original order
X_Sorted = X[X_sortedIndices]
Y_Sorted = Y[X_sortedIndices]  # keep targets aligned with the reordered inputs
X_Sorted_Embedded = embed_ENG(X_Sorted)
Y_Sorted_Embedded = embed_FR(Y_Sorted)
X_Sorted_Embedded_Packed = pack_padded_sequence(X_Sorted_Embedded, X_sequence_lengths_sorted.cpu(), batch_first=True, enforce_sorted=True)
# Pass the packed source batch through the encoder; the final hidden state
# (num_layers, batch, hidden) will seed the decoder.
encoder_output_sorted_packed, encoder_output_last_hidden_state = encoder(X_Sorted_Embedded_Packed)
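# Aside (not in the original paste): if the per-step encoder outputs were
# needed, e.g. for an attention mechanism, the PackedSequence could be
# unpacked back into a padded tensor of shape (batch, max_src_len, hidden):
encoder_outputs_padded, encoder_output_lengths = pad_packed_sequence(
    encoder_output_sorted_packed, batch_first=True)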
# Pass input through the decoder with teacher forcing.
decoded_outputs_sorted = []
decoder_first_hidden_state = encoder_output_last_hidden_state
# First decoder input is one token per sequence; here the EOS token doubles
# as the start-of-sequence marker. Use the actual batch size rather than a
# hardcoded 32 so the last (possibly smaller) batch also works.
decoder_first_input = torch.full(size=(X_Sorted.shape[0], 1), fill_value=EOS_TOKEN, dtype=torch.long)
decoder_first_input_embedded = embed_FR(decoder_first_input)
decoder_output, next_decoder_hidden_state = decoder(decoder_first_input_embedded, decoder_first_hidden_state)
decoder_output_sorted_FC = linear(decoder_output)  # (batch, 1, NUM_FRENCH_WORDS)
decoded_outputs_sorted.append(decoder_output_sorted_FC)
# Teacher forcing: the ground-truth French token at step t is the decoder
# input for step t + 1.
for sequence_index in range(Y.shape[1] - 1):
    # Reuse the precomputed target embeddings instead of re-embedding each step.
    sequence_n_words_embedded = Y_Sorted_Embedded[:, sequence_index].unsqueeze(dim=1)
    decoder_output, next_decoder_hidden_state = decoder(sequence_n_words_embedded, next_decoder_hidden_state)
    decoder_output_sorted_FC = linear(decoder_output)
    decoded_outputs_sorted.append(decoder_output_sorted_FC)
- print(f"Output Hidden Shape: {next_decoder_hidden_state.shape}")
- print(f"Output Hidden: {encoder_output_last_hidden_state}")
- decoded_outputs_sorted = torch.stack(decoded_outputs_sorted, dim=1).squeeze()
- decoded_outputs = decoded_outputs_sorted[X_unsortedIndices]
- print(f"Output Shape: {decoded_outputs.shape}")
- print(f"Output: {decoded_outputs}")