import torch
import torch.nn as nn


class LSTM(nn.Module):

    def __init__(self, input_size, hidden_size, batch_first=True):
        """Initialize params."""
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = 1
        self.batch_first = batch_first

        # Combined projections producing all four gates at once.
        self.input_weights = nn.Linear(input_size, 4 * hidden_size)
        self.hidden_weights = nn.Linear(hidden_size, 4 * hidden_size)

    def forward(self, input, hidden, ctx=None, ctx_mask=None):
        """Propagate input through the network.

        `ctx` and `ctx_mask` are accepted for interface compatibility but
        are unused in this plain LSTM cell.
        """
        def recurrence(input, hidden):
            """Recurrence helper: one LSTM time step."""
            hx, cx = hidden  # n_b x hidden_dim
            gates = self.input_weights(input) + self.hidden_weights(hx)
            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

            ingate = torch.sigmoid(ingate)          # i_t
            forgetgate = torch.sigmoid(forgetgate)  # f_t
            cellgate = torch.tanh(cellgate)         # g_t (candidate cell state)
            outgate = torch.sigmoid(outgate)        # o_t

            cy = (forgetgate * cx) + (ingate * cellgate)
            hy = outgate * torch.tanh(cy)  # n_b x hidden_dim

            return hy, cy

        if self.batch_first:
            input = input.transpose(0, 1)

        output = []
        steps = range(input.size(0))
        for i in steps:
            hidden = recurrence(input[i], hidden)
            if isinstance(hidden, tuple):
                output.append(hidden[0])
            else:
                output.append(hidden)

        output = torch.cat(output, 0).view(input.size(0), *output[0].size())

        if self.batch_first:
            output = output.transpose(0, 1)

        return output, hidden
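A minimal usage sketch of the module above. The batch size, sequence length, and feature sizes here are illustrative assumptions, not part of the original paste; the initial hidden state is passed as an (h, c) tuple as the forward method expects.

import torch

batch_size, seq_len, input_size, hidden_size = 4, 10, 32, 64

lstm = LSTM(input_size, hidden_size, batch_first=True)

x = torch.randn(batch_size, seq_len, input_size)
h0 = torch.zeros(batch_size, hidden_size)
c0 = torch.zeros(batch_size, hidden_size)

output, (hn, cn) = lstm(x, (h0, c0))
print(output.shape)  # torch.Size([4, 10, 64])
print(hn.shape)      # torch.Size([4, 64])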