SHARE
TWEET

Untitled

Posted by a guest on May 21st, 2019 · 68 views · Expires: Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  /**
   * Builds and trains a small regression network with TensorFlow.js:
   * dense(100) -> reshape to 10x10 -> stacked LSTM cells (20 units each) -> dense(1).
   *
   * Only the first `trainingsize` percent of `inputs` / `outputs` is used for
   * training. Inputs and targets are both divided by 10 before fitting, so any
   * prediction code must apply the same scaling.
   *
   * Requires the TensorFlow.js global `tf` to be loaded before it is called.
   *
   * @param {number[][]} inputs - Input rows; presumably each has `window_size` values (TODO confirm with caller).
   * @param {number[]} outputs - One target value per input row.
   * @param {number} trainingsize - Percentage (0-100) of rows used for training.
   * @param {number} window_size - Width of an input row; also used as the fit batch size.
   * @param {number} n_epochs - Number of training epochs.
   * @param {number} learning_rate - Learning rate for the Adam optimizer.
   * @param {number} n_layers - Number of stacked LSTM cells; must be at least 1.
   * @param {(epoch: number, log: object) => (void|Promise<void>)} [callback] -
   *   Optional per-epoch hook. If it returns a promise, training waits for it.
   * @returns {Promise<{model: object, stats: object}>} The trained `tf.Sequential`
   *   model and the `History` object returned by `model.fit`.
   * @throws {Error} If `n_layers` is below 1 or the training split has no rows.
   */
  async function trainModel(inputs, outputs, trainingsize, window_size, n_epochs, learning_rate, n_layers, callback) {
    const inputUnits = 100;
    const rnnFeatures = 10;
    const rnnTimesteps = inputUnits / rnnFeatures;
    // RNN layers expect each sample shaped [timesteps, features].
    const rnnInputShape = [rnnTimesteps, rnnFeatures];
    const rnnUnits = 20;
    const outputUnits = 1;
    // Fixed normalization divisor applied to both inputs and targets.
    const scale = 10;

    // Compare with `>= 1` (not Number.isInteger) so numeric strings from form fields keep working.
    if (!(n_layers >= 1)) {
      throw new Error(`trainModel: n_layers must be at least 1, got ${n_layers}`);
    }

    const xRows = inputs.slice(0, Math.floor(trainingsize / 100 * inputs.length));
    const yRows = outputs.slice(0, Math.floor(trainingsize / 100 * outputs.length));
    if (xRows.length === 0 || yRows.length === 0) {
      throw new Error(
        `trainModel: training split is empty (trainingsize=${trainingsize}, ${inputs.length} input rows)`,
      );
    }

    const model = tf.sequential();
    model.add(tf.layers.dense({ units: inputUnits, inputShape: [window_size] }));
    model.add(tf.layers.reshape({ targetShape: rnnInputShape }));

    const lstmCells = [];
    for (let i = 0; i < n_layers; i++) {
      lstmCells.push(tf.layers.lstmCell({ units: rnnUnits }));
    }
    // inputShape is only used on a Sequential model's first layer, so none is passed here.
    model.add(tf.layers.rnn({ cell: lstmCells, returnSequences: false }));
    model.add(tf.layers.dense({ units: outputUnits }));

    model.compile({
      optimizer: tf.train.adam(learning_rate),
      loss: 'meanSquaredError',
    });

    // Built after the model so a layer-construction error cannot leak them.
    // tf.tidy frees the unscaled tensors and the scalar divisor; only the scaled results survive.
    const xs = tf.tidy(() => tf.tensor2d(xRows, [xRows.length, xRows[0].length]).div(tf.scalar(scale)));
    const ys = tf.tidy(() => tf.tensor2d(yRows, [yRows.length, 1]).div(tf.scalar(scale)));

    try {
      const stats = await model.fit(xs, ys, {
        batchSize: window_size,
        epochs: n_epochs,
        callbacks: {
          onEpochEnd: async (epoch, log) => {
            if (typeof callback === 'function') {
              await callback(epoch, log);
            }
          },
        },
      });
      return { model, stats };
    } finally {
      // The training tensors are only needed during fit; release their backend memory.
      xs.dispose();
      ys.dispose();
    }
  }
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top