require 'numo/narray'
require 'pp'
require 'wrong/assert'
include Wrong::Assert

def sigmoid(x)
  1.0 / (1 + Numo::NMath.exp(-x))
end

# Derivative of sigmoid in terms of its *output*: if y = sigmoid(x),
# then dy/dx = y * (1 - y).
def sigmoid_derivative(values)
  values * (1 - values)
end

# Derivative of tanh in terms of its *output*: if y = tanh(x),
# then dy/dx = 1 - y**2.
def tanh_derivative(values)
  1.0 - values**2
end
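
# Illustrative sanity check (not part of the original paste, and not invoked
# by the script): both derivative helpers take the already-activated value,
# so feeding them sigmoid(x) / tanh(x) reproduces the analytic derivatives
# at x = 0.
def check_derivatives
  y = 1.0 / (1 + Math.exp(-0.0)) # sigmoid(0) as a plain Float => 0.5
  assert { (sigmoid_derivative(y) - 0.25).abs < 1e-12 }
  assert { (tanh_derivative(Math.tanh(0.0)) - 1.0).abs < 1e-12 }
end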

# Creates a uniform random array with values in [a, b) and the given shape.
def rand_arr(a, b, *args)
  Numo::DFloat.new(*args).rand * (b - a) + a
end

class LstmParam
  attr_accessor :mem_cell_ct, :x_dim, :wg, :wi, :wf, :wo, :bg, :bi, :bf, :bo,
                :wg_diff, :wi_diff, :wf_diff, :wo_diff,
                :bg_diff, :bi_diff, :bf_diff, :bo_diff

  def initialize(mem_cell_ct, x_dim)
    @mem_cell_ct = mem_cell_ct
    @x_dim = x_dim
    # gate inputs are x(t) concatenated with h(t-1)
    @concat_len = x_dim + mem_cell_ct

    # weight matrices
    @wg = rand_arr(-0.1, 0.1, @mem_cell_ct, @concat_len)
    @wi = rand_arr(-0.1, 0.1, @mem_cell_ct, @concat_len)
    @wf = rand_arr(-0.1, 0.1, @mem_cell_ct, @concat_len)
    @wo = rand_arr(-0.1, 0.1, @mem_cell_ct, @concat_len)

    # bias terms
    @bg = rand_arr(-0.1, 0.1, @mem_cell_ct)
    @bi = rand_arr(-0.1, 0.1, @mem_cell_ct)
    @bf = rand_arr(-0.1, 0.1, @mem_cell_ct)
    @bo = rand_arr(-0.1, 0.1, @mem_cell_ct)

    # diffs (derivative of loss function w.r.t. all parameters)
    @wg_diff = Numo::DFloat.zeros(@mem_cell_ct, @concat_len)
    @wi_diff = Numo::DFloat.zeros(@mem_cell_ct, @concat_len)
    @wf_diff = Numo::DFloat.zeros(@mem_cell_ct, @concat_len)
    @wo_diff = Numo::DFloat.zeros(@mem_cell_ct, @concat_len)
    @bg_diff = Numo::DFloat.zeros(@mem_cell_ct)
    @bi_diff = Numo::DFloat.zeros(@mem_cell_ct)
    @bf_diff = Numo::DFloat.zeros(@mem_cell_ct)
    @bo_diff = Numo::DFloat.zeros(@mem_cell_ct)
  end

  def apply_diff(lr = 1)
    @wg -= lr * @wg_diff
    @wi -= lr * @wi_diff
    @wf -= lr * @wf_diff
    @wo -= lr * @wo_diff
    @bg -= lr * @bg_diff
    @bi -= lr * @bi_diff
    @bf -= lr * @bf_diff
    @bo -= lr * @bo_diff

    # reset diffs to zero arrays of the same shape
    @wg_diff = @wg.new_zeros
    @wi_diff = @wi.new_zeros
    @wf_diff = @wf.new_zeros
    @wo_diff = @wo.new_zeros
    @bg_diff = @bg.new_zeros
    @bi_diff = @bi.new_zeros
    @bf_diff = @bf.new_zeros
    @bo_diff = @bo.new_zeros
  end
end
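
# Shape sketch (illustrative only; the method name is made up and nothing
# below calls it): each gate weight matrix is mem_cell_ct x (x_dim +
# mem_cell_ct), because every gate reads the concatenation [x(t); h(t-1)].
def demo_param_shapes
  param = LstmParam.new(100, 50)
  assert { param.wg.shape == [100, 150] }
  assert { param.bg.shape == [100] }
end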

class LstmState
  attr_accessor :g, :i, :f, :o, :s, :h, :bottom_diff_h, :bottom_diff_s

  def initialize(mem_cell_ct, _x_dim)
    @g = Numo::DFloat.zeros(mem_cell_ct)
    @i = Numo::DFloat.zeros(mem_cell_ct)
    @f = Numo::DFloat.zeros(mem_cell_ct)
    @o = Numo::DFloat.zeros(mem_cell_ct)
    @s = Numo::DFloat.zeros(mem_cell_ct)
    @h = Numo::DFloat.zeros(mem_cell_ct)
    @bottom_diff_h = @h.new_zeros
    @bottom_diff_s = @s.new_zeros
  end
end

class LstmNode
  attr_accessor :state, :param, :xc, :s_prev, :h_prev

  def initialize(lstm_param, lstm_state)
    # store reference to parameters and to activations
    @state = lstm_state
    @param = lstm_param
    # non-recurrent input concatenated with recurrent input
    @xc = nil
  end

  def bottom_data_is(x, s_prev = nil, h_prev = nil)
    # if this is the first lstm node in the network, fall back to zero state
    s_prev ||= @state.s.new_zeros
    h_prev ||= @state.h.new_zeros

    # save data for use in backprop
    @s_prev = s_prev
    @h_prev = h_prev

    # concatenate x(t) and h(t-1)
    xc = Numo::NArray.hstack([x, h_prev])
    @state.g = Numo::NMath.tanh(@param.wg.dot(xc) + @param.bg)
    @state.i = sigmoid(@param.wi.dot(xc) + @param.bi)
    @state.f = sigmoid(@param.wf.dot(xc) + @param.bf)
    @state.o = sigmoid(@param.wo.dot(xc) + @param.bo)
    @state.s = @state.g * @state.i + s_prev * @state.f
    @state.h = @state.s * @state.o

    @xc = xc
  end

  def top_diff_is(top_diff_h, top_diff_s)
    # notice that top_diff_s is carried along the constant error carousel
    ds = @state.o * top_diff_h + top_diff_s
    dopt = @state.s * top_diff_h # "do" in the math; renamed since do is a Ruby keyword
    di = @state.g * ds
    dg = @state.i * ds
    df = @s_prev * ds

    # diffs w.r.t. vector inside sigma / tanh function
    di_input = sigmoid_derivative(@state.i) * di
    df_input = sigmoid_derivative(@state.f) * df
    do_input = sigmoid_derivative(@state.o) * dopt
    dg_input = tanh_derivative(@state.g) * dg

    # accumulate diffs w.r.t. weights and biases
    @param.wi_diff += di_input.outer(@xc)
    @param.wf_diff += df_input.outer(@xc)
    @param.wo_diff += do_input.outer(@xc)
    @param.wg_diff += dg_input.outer(@xc)
    @param.bi_diff += di_input
    @param.bf_diff += df_input
    @param.bo_diff += do_input
    @param.bg_diff += dg_input

    # compute bottom diff
    dxc = @xc.new_zeros
    dxc += @param.wi.transpose.dot(di_input)
    dxc += @param.wf.transpose.dot(df_input)
    dxc += @param.wo.transpose.dot(do_input)
    dxc += @param.wg.transpose.dot(dg_input)

    # save bottom diffs
    @state.bottom_diff_s = ds * @state.f
    @state.bottom_diff_h = dxc[@param.x_dim..-1]
  end
end
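
# Illustrative single forward step (hypothetical helper, not invoked below):
# with 3 memory cells and a 2-dimensional input, the first node falls back to
# zero s_prev/h_prev and produces a 3-element hidden state.
def demo_single_node_step
  param = LstmParam.new(3, 2)
  node = LstmNode.new(param, LstmState.new(3, 2))
  node.bottom_data_is(Numo::DFloat[0.5, -0.5])
  assert { node.state.h.shape == [3] }
  node.state.h
end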

class LstmNetwork
  attr_accessor :lstm_param, :lstm_node_list, :x_list

  def initialize(lstm_param)
    @lstm_param = lstm_param
    @lstm_node_list = []
    # input sequence
    @x_list = []
  end

  # Updates diffs by setting the target sequence with the corresponding loss
  # layer. Will *NOT* update parameters; to update parameters, call
  # @lstm_param.apply_diff.
  def y_list_is(y_list, loss_layer)
    # assert is provided by gem 'wrong'. You can roll your own, but this paints a nice report.
    assert { y_list.size == @x_list.size }

    idx = @x_list.size - 1
    # the last node only gets diffs from the label ...
    loss = loss_layer.loss(@lstm_node_list[idx].state.h, y_list[idx])
    diff_h = loss_layer.bottom_diff(@lstm_node_list[idx].state.h, y_list[idx])
    # here s is not affecting loss due to h(t+1), hence we set it equal to zero
    diff_s = Numo::DFloat.zeros(@lstm_param.mem_cell_ct)
    @lstm_node_list[idx].top_diff_is(diff_h, diff_s)
    idx -= 1

    # ... earlier nodes also get diffs propagated back from the node after them
    while idx >= 0
      loss += loss_layer.loss(@lstm_node_list[idx].state.h, y_list[idx])
      diff_h = loss_layer.bottom_diff(@lstm_node_list[idx].state.h, y_list[idx])
      diff_h += @lstm_node_list[idx + 1].state.bottom_diff_h
      diff_s = @lstm_node_list[idx + 1].state.bottom_diff_s
      @lstm_node_list[idx].top_diff_is(diff_h, diff_s)
      idx -= 1
    end

    loss
  end

  def x_list_clear
    @x_list = []
  end

  def x_list_add(x)
    @x_list.push(x)
    if @x_list.size > @lstm_node_list.size
      # need to add a new lstm node; create new state mem
      lstm_state = LstmState.new(@lstm_param.mem_cell_ct, @lstm_param.x_dim)
      @lstm_node_list.push(LstmNode.new(@lstm_param, lstm_state))
    end

    # get index of most recent x input
    idx = @x_list.size - 1

    if idx.zero?
      # no recurrent inputs yet
      @lstm_node_list[idx].bottom_data_is(x)
    else
      s_prev = @lstm_node_list[idx - 1].state.s
      h_prev = @lstm_node_list[idx - 1].state.h
      @lstm_node_list[idx].bottom_data_is(x, s_prev, h_prev)
    end
  end
end
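
# Minimal train-step sketch (hypothetical helper mirroring example_0 below):
# x_list_add runs the forward pass one node per input, y_list_is backprops
# through time and only accumulates diffs, and apply_diff takes the SGD step.
def demo_train_step(net, param, xs, ys, loss_layer, lr)
  xs.each { |x| net.x_list_add(x) }
  loss = net.y_list_is(ys, loss_layer)
  param.apply_diff(lr)
  net.x_list_clear
  loss
end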

# TESTING LSTM

# Computes squared loss against the first element of the hidden layer array.
class ToyLossLayer
  def self.loss(pred, label)
    (pred[0] - label)**2
  end

  def self.bottom_diff(pred, label)
    diff = pred.new_zeros
    diff[0] = 2 * (pred[0] - label)
    diff
  end
end
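
# Quick numeric illustration (hypothetical helper): only pred[0] enters the
# loss, so bottom_diff is zero everywhere except its first element. For
# pred = [0.3, 0.7] and label = 0.1, loss is 0.2**2 and the diff is [0.4, 0.0].
def demo_toy_loss
  pred = Numo::DFloat[0.3, 0.7]
  [ToyLossLayer.loss(pred, 0.1), ToyLossLayer.bottom_diff(pred, 0.1)]
end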

def example_0
  # parameters for input data dimension and lstm cell count
  mem_cell_ct = 100
  x_dim = 50
  lstm_param = LstmParam.new(mem_cell_ct, x_dim)
  lstm_net = LstmNetwork.new(lstm_param)
  y_list = [-0.5, 0.2, 0.1, -0.5]
  input_val_arr = Array.new(y_list.size) { Numo::DFloat.new(x_dim).rand }

  100.times do |cur_iter|
    pp "cur iter: #{cur_iter}"
    y_list.size.times do |ind|
      lstm_net.x_list_add(input_val_arr[ind])
      pp "y_pred[#{ind}]: #{lstm_net.lstm_node_list[ind].state.h[0]}"
    end

    loss = lstm_net.y_list_is(y_list, ToyLossLayer)
    pp "loss: #{loss}"
    lstm_param.apply_diff(0.1)
    lstm_net.x_list_clear
  end
end

if __FILE__ == $0
  example_0
end
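
# To try this out (assuming the file is saved as lstm.rb; both gems are on
# rubygems.org):
#   gem install numo-narray wrong
#   ruby lstm.rb
# Each iteration prints the four predictions h[0] and the summed squared
# loss; the predictions should drift toward the targets in y_list.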