import numpy as np
import chainer
from chainer import Variable, optimizers, serializers
import chainer.functions as F
import chainer.links as L

# Translator class (an encoder-decoder translation model extended with attention)
class Translator(chainer.Chain):
    def __init__(self, debug=False, source='en.txt', target='ja.txt', embed_size=100):
        source_lines, source_word2id, _ = self.load_language(source)
        target_lines, target_word2id, target_id2word = self.load_language(target)

        source_size = len(source_word2id)
        target_size = len(target_word2id)
        super(Translator, self).__init__(
            embed_x=L.EmbedID(source_size, embed_size),  # source word embeddings
            embed_y=L.EmbedID(target_size, embed_size),  # target word embeddings
            H=L.LSTM(embed_size, embed_size),            # shared encoder/decoder LSTM
            Wc1=L.Linear(embed_size, embed_size),        # attention: projects the context vector
            Wc2=L.Linear(embed_size, embed_size),        # attention: projects the decoder state
            W=L.Linear(embed_size, target_size),         # output projection onto the target vocabulary
        )
        # Plain attributes are set after super().__init__ so the Chain is fully initialized first
        self.embed_size = embed_size
        self.source_lines, self.source_word2id = source_lines, source_word2id
        self.target_lines, self.target_word2id = target_lines, target_word2id
        self.target_id2word = target_id2word
        self.optimizer = optimizers.Adam()
        self.optimizer.setup(self)

        if debug:
            print("embed_size: {0}".format(embed_size), end="")
            print(", source_size: {0}".format(source_size), end="")
            print(", target_size: {0}".format(target_size))

    # Train for one pass over the parallel corpus (one sentence pair per update)
    def learn(self, debug=False):
        line_num = len(self.source_lines) - 1  # the last element of split('\n') is an empty string
        for i in range(line_num):
            source_words = self.source_lines[i].split()
            target_words = self.target_lines[i].split()

            self.H.reset_state()  # clear the LSTM state between sentence pairs
            self.zerograds()
            loss = self.loss(source_words, target_words)
            loss.backward()
            loss.unchain_backward()  # cut the computational graph to free memory
            self.optimizer.update()

            if debug:
                print("{0} / {1} line finished.".format(i + 1, line_num))

    # Greedily translate a list of source words (at most ~30 output words)
    def test(self, source_words):
        self.H.reset_state()  # start from a clean LSTM state
        bar_h_i_list = self.h_i_list(source_words, True)
        x_i = self.embed_x(Variable(np.array([self.source_word2id['<eos>']], dtype=np.int32), volatile='on'))
        h_t = self.H(x_i)
        c_t = self.c_t(bar_h_i_list, h_t.data[0], True)

        result = []
        bar_h_t = F.tanh(self.Wc1(c_t) + self.Wc2(h_t))
        wid = np.argmax(F.softmax(self.W(bar_h_t)).data[0])
        result.append(self.target_id2word[wid])

        loop = 0
        while (wid != self.target_word2id['<eos>']) and (loop <= 30):
            y_i = self.embed_y(Variable(np.array([wid], dtype=np.int32), volatile='on'))
            h_t = self.H(y_i)
            c_t = self.c_t(bar_h_i_list, h_t.data[0], True)  # h_t.data[0] for consistency with the first step

            bar_h_t = F.tanh(self.Wc1(c_t) + self.Wc2(h_t))
            wid = np.argmax(F.softmax(self.W(bar_h_t)).data[0])
            result.append(self.target_id2word[wid])
            loop += 1
        return result

    # Compute the loss for one sentence pair
    def loss(self, source_words, target_words):
        bar_h_i_list = self.h_i_list(source_words)
        x_i = self.embed_x(Variable(np.array([self.source_word2id['<eos>']], dtype=np.int32)))
        h_t = self.H(x_i)
        c_t = self.c_t(bar_h_i_list, h_t.data[0])

        # Predict the first target word from <eos>
        bar_h_t = F.tanh(self.Wc1(c_t) + self.Wc2(h_t))
        tx = Variable(np.array([self.target_word2id[target_words[0]]], dtype=np.int32))
        accum_loss = F.softmax_cross_entropy(self.W(bar_h_t), tx)
        # Teacher forcing: feed each target word and predict the following one
        for i in range(len(target_words)):
            wid = self.target_word2id[target_words[i]]
            y_i = self.embed_y(Variable(np.array([wid], dtype=np.int32)))
            h_t = self.H(y_i)
            c_t = self.c_t(bar_h_i_list, h_t.data[0])

            bar_h_t = F.tanh(self.Wc1(c_t) + self.Wc2(h_t))
            next_wid = self.target_word2id['<eos>'] if (i == len(target_words) - 1) else self.target_word2id[target_words[i + 1]]
            tx = Variable(np.array([next_wid], dtype=np.int32))
            accum_loss += F.softmax_cross_entropy(self.W(bar_h_t), tx)
        return accum_loss

    # Run the encoder over the source words and collect the hidden states h_i
    def h_i_list(self, words, test=False):
        h_i_list = []
        volatile = 'on' if test else 'off'
        for word in words:
            wid = self.source_word2id[word]
            x_i = self.embed_x(Variable(np.array([wid], dtype=np.int32), volatile=volatile))
            h_i = self.H(x_i)
            h_i_list.append(np.copy(h_i.data[0]))
        return h_i_list

    # Compute the context vector c_t by dot-product attention:
    #   alpha_t_i = exp(h_t . h_i) / sum_j exp(h_t . h_j)
    #   c_t = sum_i alpha_t_i * h_i
    def c_t(self, bar_h_i_list, h_t, test=False):
        s = 0.0
        for bar_h_i in bar_h_i_list:
            s += np.exp(h_t.dot(bar_h_i))

        c_t = np.zeros(self.embed_size)
        for bar_h_i in bar_h_i_list:
            alpha_t_i = np.exp(h_t.dot(bar_h_i)) / s
            c_t += alpha_t_i * bar_h_i
        volatile = 'on' if test else 'off'
        c_t = Variable(np.array([c_t]).astype(np.float32), volatile=volatile)
        return c_t

    # Load a corpus and build its vocabulary
    def load_language(self, filename):
        word2id = {}
        with open(filename) as f:
            lines = f.read().split('\n')
        for line in lines:
            for word in line.split():
                if word not in word2id:
                    word2id[word] = len(word2id)
        word2id['<eos>'] = len(word2id)
        id2word = {v: k for k, v in word2id.items()}
        return lines, word2id, id2word

    # Load model parameters from a file
    def load_model(self, filename):
        serializers.load_npz(filename, self)

    # Save model parameters to a file
    def save_model(self, filename):
        serializers.save_npz(filename, self)
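
A minimal usage sketch (not part of the original paste): it assumes Chainer 1.x (the volatile flag used above was removed in Chainer 2), aligned parallel corpora en.txt and ja.txt with one whitespace-tokenized sentence per line, and that every word passed to test() appeared in the training corpus (unknown words raise KeyError). The file names and epoch count are illustrative.

# Hypothetical driver script for the Translator class above.
translator = Translator(debug=True, source='en.txt', target='ja.txt', embed_size=100)
for epoch in range(10):  # the number of passes is an arbitrary choice
    translator.learn(debug=True)
    translator.save_model('translator_{0}.npz'.format(epoch))

# Translate one sentence; tokenize it the same way as the corpus.
print(' '.join(translator.test('i am a student .'.split())))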