Guest User

Untitled

a guest
Nov 19th, 2017
99
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.24 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. from torch.autograd import Variable
  4. from torch.nn import Parameter
  5. import numpy as np
  6. import math
  7.  
  8. # TODO: Your implementation goes here
  9. class Encoder(nn.Module):
  10. def __init__(self, vocab_size):
  11. super(Encoder, self).__init__()
  12.  
  13. self.embedding = torch.nn.Embedding(vocab_size, 300)
  14.  
  15. self.lstm = torch.nn.LSTM(input_size=300, hidden_size=512, num_layers=1, bidirectional=True)
  16.  
  17. self.hidden = None
  18.  
  19.  
  20. def forward(self, input):
  21. embedded = self.embedding(input)
  22. output = embedded
  23. output, self.hidden = self.lstm(output, self.hidden)
  24. return output, self.hidden
  25.  
  26.  
  27. class Attn(nn.Module):
  28. def __init__(self, method, hidden_size):
  29. super(Attn, self).__init__()
  30.  
  31. self.method = method
  32. self.hidden_size = hidden_size
  33.  
  34. self.softmax = nn.LogSoftmax()
  35.  
  36. self.attn = nn.Linear(self.hidden_size, hidden_size)
  37.  
  38. def forward(self, hidden, encoder_outputs):
  39. max_len = encoder_outputs.size(0)
  40. this_batch_size = encoder_outputs.size(1)
  41.  
  42. # Create variable to store attention energies
  43. attn_energies = Variable(torch.zeros(this_batch_size, max_len)) # B x S
  44.  
  45. # For each batch of encoder outputs
  46. for b in range(this_batch_size):
  47. # Calculate energy for each encoder output
  48. for i in range(max_len):
  49. attn_energies[b, i] = self.score(hidden[:, b], encoder_outputs[i, b].unsqueeze(0))
  50.  
  51. # Normalize energies to weights in range 0 to 1, resize to 1 x B x S
  52. return self.softmax(attn_energies).unsqueeze(1)
  53.  
  54. def score(self, hidden, encoder_output):
  55. energy = self.attn(encoder_output)
  56. energy = hidden.dot(energy)
  57. return energy
  58.  
  59. class Decoder(nn.Module):
  60. def __init__(self, src_vocab_size, trg_vocab_size):
  61. super(Decoder, self).__init__()
  62.  
  63. hidden_size = 1024
  64.  
  65. self.initParams = torch.load(open("model.param", "rb"))
  66.  
  67. self.embedding = torch.nn.Embedding(trg_vocab_size, 300)
  68.  
  69.  
  70. self.attn = Attn('general', 1024)
  71.  
  72. self.attin = torch.nn.Linear(1024, 1024)
  73. self.softmax = nn.LogSoftmax()
  74.  
  75. self.attout = torch.nn.Linear(2048, 1024)
  76.  
  77.  
  78. self.lstm = torch.nn.LSTM(input_size=1324, hidden_size=1024, num_layers=1)
  79.  
  80. #23262
  81. self.gen = torch.nn.Linear(1024,trg_vocab_size)
  82. self.hidden = None#Parameter(torch.randn((48, 1024)))
  83. self.prev = Variable(torch.randn((48, 1024)))
  84.  
  85. self.concat = nn.Linear(hidden_size * 2, hidden_size)
  86.  
  87.  
  88. def forward(self, targ, encoder_out):
  89. embedded = self.embedding(targ)
  90. output = embedded
  91.  
  92. #sc = Variable(torch.zeros((len(encoder_out),48,)))
  93. #for i in range(len(encoder_out)):
  94. # sc[i] = self.score(encoder_out[i], self.hidden)
  95. #a = self.softmax(sc)
  96. #a = a.repeat(1024,1,1).transpose(0,1).transpose(1,2)
  97. ##broadcasting
  98. #mult = torch.mul(a, encoder_out)
  99. #s = torch.sum(mult, 0)
  100. #
  101. #context = torch.tanh(self.attout(torch.cat((s, self.hidden), 1)))
  102. #output = torch.cat((context.repeat(len(embedded), 1,1), embedded), 2
  103.  
  104. sc = Variable(torch.zeros((len(encoder_out),48,)))
  105. for i in range(len(encoder_out)):
  106. sc[i] = self.score(encoder_out[i], self.prev)
  107. a = self.softmax(sc)
  108. context = a.bmm(encoder_out.transpose(0, 1))
  109.  
  110. prev = self.prev.squeeze(0) # S=1 x B x N -> B x N
  111. context = context.squeeze(1) # B x S=1 x N -> B x N
  112. concat_input = torch.cat((prev, context), 1)
  113. concat_output = torch.tanh(self.concat(concat_input))
  114.  
  115. #a = a.repeat(1024,1,1).transpose(0,1).transpose(1,2)
  116. #broadcasting
  117. #mult = torch.mul(encoder_out, a.unsqueeze(2).repeat(1, 1, 1024))
  118. #s = torch.sum(mult, 0)
  119. #
  120. #context = torch.tanh(self.attout(torch.cat((s, self.prev), 1)))
  121. #output = torch.cat((context.repeat(len(embedded), 1,1), embedded), 2)
  122.  
  123. #--------------------------------
  124.  
  125. #output, hiddenN = self.lstm(embedded, self.hidden)
  126.  
  127. ## Calculate attention from current RNN state and all encoder outputs;
  128. ## apply to encoder outputs to get weighted average
  129. #attn_weights = self.attn(output, encoder_outputs)
  130. #context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) # B x S=1 x N
  131. #
  132. ## Attentional vector using the RNN hidden state and context vector
  133. ## concatenated together (Luong eq. 5)
  134. #rnn_output = rnn_output.squeeze(0) # S=1 x B x N -> B x N
  135. #context = context.squeeze(1) # B x S=1 x N -> B x N
  136. #concat_input = torch.cat((rnn_output, context), 1)
  137. #concat_output = F.tanh(self.concat(concat_input))
  138.  
  139.  
  140. self.prev, self.hidden = self.lstm(concat_out, self.hidden)#Variable(torch.zeros((2, 48, 1024,))))
  141. generated = self.gen(self.prev)
  142. return generated
  143.  
  144. def score(self, h_s, h_t):
  145. h_t = self.attin(h_t)
  146. print h_s
  147. return h_t.unsqueeze(1).bmm(h_s.unsqueeze(1).transpose(1,2))
  148. #a = Variable(torch.zeros((48,1)))
  149. #for i in range(len(h_t)):
  150. # a[i] = torch.dot(h_s[i], h_t[i])
  151. #return a.transpose(0,1)
  152.  
  153.  
  154.  
  155. class NMT(nn.Module):
  156. def __init__(self, src_vocab_size, trg_vocab_size):
  157. super(NMT, self).__init__()
  158. self.Encoder = Encoder(src_vocab_size)
  159. self.Decoder = Decoder(src_vocab_size, trg_vocab_size)
  160.  
  161. def forward(self, input, targ):
  162.  
  163. encout, hidden = self.Encoder(input)
  164.  
  165. return self.Decoder(targ, encout)
Add Comment
Please, Sign In to add comment