Advertisement
Guest User

Untitled

a guest
Dec 6th, 2016
60
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.42 KB | None | 0 0
  1. const math = require('./math');
  2.  
  /**
   * Simple (Elman-style) recurrent neural network:
   *   u[t] = U·x[t] + W·s[t-1] + b,  s[t] = activation(u[t])
   *   v[t] = V·s[t] + c,             y[t] = math.fn.linear(v[t])
   * trained with truncated backpropagation through time (BPTT) and plain SGD.
   *
   * All tensor arithmetic is delegated to the project-local `math` module; the
   * shapes documented below are inferred from how its helpers are called here.
   *
   * TODO: the output layer is fixed to `math.fn.linear`. A sigmoid/softmax
   * output would also need a matching output delta in `backProp`.
   */
  class RNN {
    /**
     * @param {number} nIn - Size of each input vector x[t].
     * @param {number} nHidden - Size of the hidden state s[t].
     * @param {number} nOut - Size of each output vector y[t].
     * @param {number} [truncatedTime=3] - Max number of steps an error is propagated back in `backProp`.
     * @param {number} [learningRate=0.1] - Default SGD step size used by `sgd`.
     * @param {Function} [activation=math.fn.tanh] - Hidden activation; must also expose `.grad(u)`.
     * @param {Function} [rng=Math.random] - Random source handed to `math.array.uniform` for weight init.
     */
    constructor(nIn, nHidden, nOut, truncatedTime = 3, learningRate = 0.1, activation = math.fn.tanh, rng = Math.random) {
      this.nIn = nIn;
      this.nHidden = nHidden;
      this.nOut = nOut;
      this.truncatedTime = truncatedTime;
      this.learningRate = learningRate;
      this.activation = activation;

      // Each weight matrix is drawn uniformly from ±sqrt(1 / fan-in).
      const inBound = Math.sqrt(1 / nIn);
      const hiddenBound = Math.sqrt(1 / nHidden);
      this.U = math.array.uniform(-inBound, inBound, rng, [nHidden, nIn]); // input -> hidden
      this.V = math.array.uniform(-hiddenBound, hiddenBound, rng, [nOut, nHidden]); // hidden -> output
      this.W = math.array.uniform(-hiddenBound, hiddenBound, rng, [nHidden, nHidden]); // hidden -> hidden

      this.b = math.array.zeros(nHidden); // hidden bias
      this.c = math.array.zeros(nOut); // output bias
    }

    /**
     * Runs the network over a whole sequence.
     *
     * @param {number[][]} x - Input sequence, indexed x[time][index].
     * @returns {{s: Array, u: Array, y: Array, v: Array}} Per-time-step values:
     *   `u` hidden pre-activations, `s` hidden states, `v` output
     *   pre-activations, and `y` outputs.
     */
    forwardProp(x) {
      const timeLength = x.length;

      const s = math.array.zeros(timeLength, this.nHidden);
      const u = math.array.zeros(timeLength, this.nHidden);
      const y = math.array.zeros(timeLength, this.nOut);
      const v = math.array.zeros(timeLength, this.nOut);

      for (let t = 0; t < timeLength; t++) {
        // The hidden state before the first step is all zeros.
        const prevState = t === 0 ? math.array.zeros(this.nHidden) : s[t - 1];
        u[t] = math.add(math.add(math.dot(this.U, x[t]), math.dot(this.W, prevState)), this.b);
        s[t] = this.activation(u[t]);

        v[t] = math.add(math.dot(this.V, s[t]), this.c);
        y[t] = math.fn.linear(v[t]);
      }

      return { s, u, y, v };
    }

    /**
     * Computes parameter gradients for one sequence with truncated BPTT.
     *
     * @param {number[][]} x - Input sequence, indexed x[time][index].
     * @param {number[][]} label - Targets; assumed to have the same shape as the
     *   forward output y (x.length × nOut). TODO confirm against callers.
     * @returns {{grad: {U: Array, V: Array, W: Array, b: Array, c: Array}}}
     *   Gradients summed over all time steps. The weights are not modified.
     */
    backProp(x, label) {
      let dU = math.array.zeros(this.nHidden, this.nIn);
      let dV = math.array.zeros(this.nOut, this.nHidden);
      let dW = math.array.zeros(this.nHidden, this.nHidden);
      let db = math.array.zeros(this.nHidden);
      let dc = math.array.zeros(this.nOut);

      const timeLength = x.length;
      const { s, u, y, v } = this.forwardProp(x);

      // Output delta (y - label) ⊙ linear'(v), i.e. the gradient of ½‖y - label‖² w.r.t. v.
      const eo = math.mul(math.sub(y, label), math.fn.linear.grad(v));
      // Hidden deltas. Slot k is overwritten whenever an error is propagated into step k.
      const eh = math.array.zeros(timeLength, this.nHidden);

      for (let t = timeLength - 1; t >= 0; t--) {
        dV = math.add(dV, math.outer(eo[t], s[t]));
        dc = math.add(dc, eo[t]);
        // Output error pulled back through V (vector·matrix) and the hidden activation.
        eh[t] = math.mul(math.dot(eo[t], this.V), this.activation.grad(u[t]));

        // Propagate the error originating at step t back at most `truncatedTime` steps.
        for (let z = 0; z < this.truncatedTime; z++) {
          const k = t - z;
          if (k < 0) {
            break;
          }

          dU = math.add(dU, math.outer(eh[k], x[k]));
          db = math.add(db, eh[k]);

          if (k - 1 >= 0) {
            dW = math.add(dW, math.outer(eh[k], s[k - 1]));
            eh[k - 1] = math.mul(math.dot(eh[k], this.W), this.activation.grad(u[k - 1]));
          }
        }
      }

      return {
        grad: {
          U: dU,
          V: dV,
          W: dW,
          b: db,
          c: dc,
        },
      };
    }

    /**
     * Performs one SGD step on a single sequence, updating every parameter
     * (U, V, W and the biases b, c) in place on `this`.
     *
     * @param {number[][]} x - Input sequence, indexed x[time][index].
     * @param {number[][]} label - Targets, same layout as the forward output y.
     * @param {number} [learningRate] - Step size. Falls back to `this.learningRate`
     *   only when undefined/null, so an explicit 0 is honoured.
     */
    sgd(x, label, learningRate) {
      const lr = learningRate == null ? this.learningRate : learningRate;
      const { grad } = this.backProp(x, label);

      this.U = math.sub(this.U, math.mul(lr, grad.U));
      this.V = math.sub(this.V, math.mul(lr, grad.V));
      this.W = math.sub(this.W, math.mul(lr, grad.W));
      // The bias gradients must be applied too; otherwise b and c never move
      // from their zero initialisation.
      this.b = math.sub(this.b, math.mul(lr, grad.b));
      this.c = math.sub(this.c, math.mul(lr, grad.c));
    }

    /**
     * Returns the network outputs for a sequence.
     *
     * @param {number[][]} x - Input sequence, indexed x[time][index].
     * @returns {Array} Outputs y, indexed y[time][index].
     */
    predict(x) {
      return this.forwardProp(x).y;
    }
  }
  112.  
  113.  
  114. module.exports = RNN;
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement