Advertisement
Guest User

Untitled

a guest
May 21st, 2019
91
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.72 KB | None | 0 0
  /**
   * Build, compile and train a dense -> stacked-LSTM -> dense regression model
   * with TensorFlow.js (uses the global `tf`).
   *
   * Only the first `trainingsize` percent of `inputs` / `outputs` is used for
   * training. Inputs and targets are divided by 10 before fitting, so the
   * returned model works on that same scaled range.
   *
   * @param {number[][]} inputs - Input rows; each row presumably has `window_size` values
   *   (it must match the first layer's inputShape of [window_size]).
   * @param {number[]} outputs - One target value per input row.
   * @param {number} trainingsize - Percentage (0-100) of the samples to train on.
   * @param {number} window_size - Input row length; also used as the training batch size.
   * @param {number} n_epochs - Number of training epochs.
   * @param {number} learning_rate - Learning rate for the Adam optimizer.
   * @param {number} n_layers - Number of stacked LSTM cells (must be >= 1).
   * @param {function(number, Object)} [callback] - Optional hook invoked after each
   *   epoch with (epoch, logs); it is awaited if it returns a promise.
   * @returns {Promise<{model: Object, stats: Object}>} The trained tf.Sequential model
   *   and the History object returned by `model.fit`.
   * @throws {Error} If `n_layers` < 1 or the training split is empty.
   */
  async function trainModel(inputs, outputs, trainingsize, window_size, n_epochs, learning_rate, n_layers, callback) {
    const input_layer_shape = window_size;
    const input_layer_neurons = 100;

    // The first dense layer's outputs are reshaped into a sequence for the LSTM,
    // so input_layer_neurons must be divisible by rnn_input_layer_features.
    const rnn_input_layer_features = 10;
    const rnn_input_layer_timesteps = input_layer_neurons / rnn_input_layer_features;
    // RNN layers consume [timesteps, features] (time axis first).
    const rnn_input_shape = [rnn_input_layer_timesteps, rnn_input_layer_features];
    const rnn_output_neurons = 20;

    const rnn_batch_size = window_size;
    const output_layer_neurons = 1;

    // Written as a negated >= so NaN/undefined are rejected as well.
    if (!(n_layers >= 1)) {
      throw new Error(`trainModel: n_layers must be >= 1, got ${n_layers}`);
    }

    const X = inputs.slice(0, Math.floor(trainingsize / 100 * inputs.length));
    const Y = outputs.slice(0, Math.floor(trainingsize / 100 * outputs.length));
    if (X.length === 0 || Y.length === 0) {
      throw new Error('trainModel: training split is empty; check trainingsize and inputs/outputs');
    }

    const model = tf.sequential();
    model.add(tf.layers.dense({ units: input_layer_neurons, inputShape: [input_layer_shape] }));
    model.add(tf.layers.reshape({ targetShape: rnn_input_shape }));

    const lstm_cells = [];
    for (let index = 0; index < n_layers; index++) {
      lstm_cells.push(tf.layers.lstmCell({ units: rnn_output_neurons }));
    }
    // An array of cells is stacked; returnSequences:false keeps only the last step's output.
    // inputShape is only honoured on the first layer of a Sequential model, so it is omitted here.
    model.add(tf.layers.rnn({ cell: lstm_cells, returnSequences: false }));
    model.add(tf.layers.dense({ units: output_layer_neurons }));

    model.compile({
      optimizer: tf.train.adam(learning_rate),
      loss: 'meanSquaredError',
    });

    // tf.tidy releases the unscaled tensors and the scalar divisors; only xs/ys survive.
    // Created after model construction so a build/compile error cannot leak them.
    const { xs, ys } = tf.tidy(() => ({
      xs: tf.tensor2d(X, [X.length, X[0].length]).div(tf.scalar(10)),
      ys: tf.tensor2d(Y, [Y.length, 1]).div(tf.scalar(10)),
    }));

    try {
      const hist = await model.fit(xs, ys, {
        batchSize: rnn_batch_size,
        epochs: n_epochs,
        callbacks: {
          onEpochEnd: async (epoch, log) => {
            // The callback is optional; awaiting lets an async callback finish before training continues.
            if (typeof callback === 'function') {
              await callback(epoch, log);
            }
          },
        },
      });
      return { model: model, stats: hist };
    } finally {
      // The training tensors are not part of the returned model; release the memory TF.js holds for them.
      xs.dispose();
      ys.dispose();
    }
  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement