Advertisement
Mikestriken

seq2seq packed sequence batched forward pass

Mar 19th, 2025
173
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.51 KB | None | 0 0
  1. embed_ENG:nn.Embedding = nn.Embedding(num_embeddings=NUM_ENGLISH_WORDS, embedding_dim=128, padding_idx=PADDING_TOKEN)
  2. embed_FR:nn.Embedding = nn.Embedding(num_embeddings=NUM_FRENCH_WORDS, embedding_dim=128, padding_idx=PADDING_TOKEN)
  3. encoder:nn.GRU = nn.GRU(input_size=128, hidden_size=128,
  4.                     batch_first=True, dropout=DROPOUT_PROB, num_layers=2)
  5.  
  6. decoder:nn.GRU = nn.GRU(input_size=128, hidden_size=128,
  7.                     batch_first=True, dropout=DROPOUT_PROB, num_layers=2)
  8.  
  9. linear:nn.Linear = nn.Linear(in_features=128, out_features=NUM_FRENCH_WORDS)
  10.  
  11. firstBatch = next(iter(test_loader))
  12. X, Y, X_sequence_lengths, Y_sequence_lengths = firstBatch
  13. # X, Y, sequence_lengths = X.to(DEVICE), Y.to(DEVICE), sequence_lengths.to(DEVICE)
  14.  
  15. X_sequence_lengths_sorted, X_sortedIndices = X_sequence_lengths.sort(descending=True)
  16. _, X_unsortedIndices = X_sortedIndices.sort()
  17.  
  18.  
  19. X_Sorted = X[X_sortedIndices]
  20. Y_Sorted = Y[X_sortedIndices]
  21.  
  22. X_Sorted_Embedded = embed_ENG(X_Sorted)
  23. Y_Sorted_Embedded = embed_FR(Y_Sorted)
  24.  
  25. X_Sorted_Embedded_Packed = pack_padded_sequence(X_Sorted_Embedded, X_sequence_lengths_sorted.cpu(), batch_first=True, enforce_sorted=True)
  26.      
  27. # Pass input through encoder
  28. encoder_output_sorted_packed, encoder_output_last_hidden_state = encoder(X_Sorted_Embedded_Packed)
  29.  
  30. # Pass Input through decoder
  31. decoded_outputs_sorted = []
  32. decoder_first_hidden_state = encoder_output_last_hidden_state
  33. decoder_first_input = torch.full(size=(32,1), fill_value=EOS_TOKEN)
  34. decoder_first_input_embedded = embed_FR(decoder_first_input)
  35.  
  36. decoder_output, next_decoder_hidden_state = decoder(decoder_first_input_embedded, decoder_first_hidden_state)
  37.  
  38. decoder_output_sorted_FC = linear(decoder_output)
  39. decoded_outputs_sorted.append(decoder_output_sorted_FC)
  40.  
  41. for sequence_index in range(Y.shape[1] - 1):
  42.     sequence_n_words = Y_Sorted[:, sequence_index].unsqueeze(dim=1)
  43.     sequence_n_words_embedded = embed_FR(sequence_n_words)
  44.     decoder_output, next_decoder_hidden_state = decoder(sequence_n_words_embedded, next_decoder_hidden_state)
  45.     decoder_output_sorted_FC = linear(decoder_output)
  46.     decoded_outputs_sorted.append(decoder_output_sorted_FC)
  47.    
  48.  
  49. print(f"Output Hidden Shape: {next_decoder_hidden_state.shape}")
  50. print(f"Output Hidden: {encoder_output_last_hidden_state}")
  51.  
  52. decoded_outputs_sorted = torch.stack(decoded_outputs_sorted, dim=1).squeeze()
  53. decoded_outputs = decoded_outputs_sorted[X_unsortedIndices]
  54. print(f"Output Shape: {decoded_outputs.shape}")
  55. print(f"Output: {decoded_outputs}")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement