# Michael A. Alcorn

import torch
import torch.autograd as autograd
import torch.nn as nn

def create_lstm(params):
    """Create an LSTM from a dictionary of parameters.

    :param params: dict of keyword arguments accepted by nn.LSTM.
    :return: nn.LSTM module.
    """
    return nn.LSTM(**params)

def create_h_0_c_0(params):
    """Create variables containing an LSTM's initial hidden and cell states.

    Note that this also stores "num_directions" and "l_by_d" back into params
    for later use.

    :param params: dict of LSTM parameters (must include "bidirectional",
        "num_layers", and "hidden_size").
    :return: tuple (h_0, c_0) of Variables, each of shape
        (num_layers * num_directions, hidden_size).
    """
    num_directions = 2 if params["bidirectional"] else 1
    l_by_d = params["num_layers"] * num_directions
    params["num_directions"] = num_directions
    params["l_by_d"] = l_by_d
    hidden_size = params["hidden_size"]
    h_0_var = autograd.Variable(torch.randn(l_by_d, hidden_size), requires_grad=True)
    c_0_var = autograd.Variable(torch.randn(l_by_d, hidden_size), requires_grad=True)
    return (h_0_var, c_0_var)
# Define the size of the input at each step for the encoder and decoder.
input_size = {"e": 6, "d": 4}

# Create the encoder.
e_d = {"e": {"input_size": input_size["e"],
             "hidden_size": 5,
             "num_layers": 3,
             "batch_first": True,
             "bidirectional": True}}
encoder = create_lstm(e_d["e"])

# Create the initial hidden state variables for the encoder.
h_0 = {}
c_0 = {}
(h_0["e"], c_0["e"]) = create_h_0_c_0(e_d["e"])
# Create the decoder.
input_size_d = input_size["d"] + e_d["e"]["num_directions"] * e_d["e"]["hidden_size"]
e_d["d"] = {"input_size": input_size_d,
            "hidden_size": 9,
            "num_layers": 2,
            "batch_first": True,
            "bidirectional": False}
decoder = create_lstm(e_d["d"])

# Create initial hidden state variables for the decoder.
(h_0["d"], c_0["d"]) = create_h_0_c_0(e_d["d"])
# Create the attention mechanism.
attn = nn.Linear(e_d["e"]["num_directions"] * e_d["e"]["hidden_size"] + e_d["d"]["hidden_size"], 1)
attn_weights = nn.Softmax(dim=1)
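# Added note (not in the original paste): the linear layer maps the
# (2 * 5 + 9) = 19-dimensional concatenation of an encoder output and the
# decoder hidden state to a single unnormalized score per encoder position;
# the softmax over dim=1 then normalizes those scores across the encoder
# time steps.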
# Create dummy input and output sequences.
seq_lens = {"e": 7, "d": 8}
seqs = {}
for (e_or_d, seq_len) in seq_lens.items():
    x = [autograd.Variable(torch.randn((1, input_size[e_or_d]))) for _ in range(seq_len)]
    seqs[e_or_d] = torch.cat(x).view(1, len(x), input_size[e_or_d])
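# Added note (not in the original paste): with batch_first=True, seqs["e"] has
# shape (1, 7, 6) and seqs["d"] has shape (1, 8, 4), i.e.,
# (batch, seq_len, input_size).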
# Calculate encoder outputs.
(out, (h_e, c_e)) = encoder(seqs["e"], (h_0["e"].unsqueeze(1), c_0["e"].unsqueeze(1)))
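# Added note (not in the original paste): unsqueeze(1) adds the batch
# dimension expected by nn.LSTM, giving initial states of shape (6, 1, 5);
# out has shape (1, 7, 10) = (batch, seq_len, num_directions * hidden_size).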
# Calculate hidden states for the decoder using the attention mechanism.
(h_t, c_t) = (h_0["d"].unsqueeze(1), c_0["d"].unsqueeze(1))
num_directions_d = e_d["d"]["num_directions"]
hidden_size_d = e_d["d"]["hidden_size"]
input_size_d = input_size["d"]
for i in range(seqs["d"].size(1)):
    # h_t[0] is the hidden state of the bottom decoder layer.
    h_att = h_t[0].squeeze(1).unsqueeze(0)
    # Broadcast the decoder hidden state across the encoder time steps.
    h_ex = h_att.expand(1, seq_lens["e"], hidden_size_d)
    # Score each encoder output against the decoder hidden state.
    concat = torch.cat((out, h_ex), 2)
    att = attn(concat)
    att_w = attn_weights(att)
    # Compute the attention-weighted sum of the encoder outputs.
    attn_applied = torch.bmm(att_w.view(1, 1, seq_lens["e"]), out)
    # Concatenate the current decoder input with the attention context.
    new_input = torch.cat((seqs["d"][0, i].view(1, 1, input_size_d), attn_applied), dim=2)
    (out_t, (h_t, c_t)) = decoder(new_input, (h_t, c_t))
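# Illustrative sanity check (added, not in the original paste): after the
# loop, the last decoder step's output and the final hidden/cell states
# should have the shapes noted below.
print(out_t.size())  # (1, 1, 9)
print(h_t.size())    # (2, 1, 9)
print(c_t.size())    # (2, 1, 9)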