Advertisement
jokeris

trunc keras

Jan 5th, 2018
42
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.19 KB | None
  1. import keras
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from keras.models import Sequential
  5. from keras.layers import Dense, SimpleRNN, Masking
  6. from keras.callbacks import Callback
  7.  
  8. DTYPE = np.float64
  9. TIMESTEPS = 50
  10.  
  11. class ResetStatesCallback(Callback):
  12.     def __init__(self, n_batches):
  13.         super(ResetStatesCallback, self).__init__()
  14.         self.n_batches = n_batches
  15.  
  16.     def on_epoch_begin(self, epoch, logs=None):
  17.         self.model.reset_states()
  18.  
  19.     def on_batch_end(self, batch, logs=None):
  20.         if (batch+1) % self.n_batches == 0:
  21.             self.model.reset_states()
  22.  
  23. def pad(sequence, timesteps):
  24.     if sequence.shape[0] % timesteps == 0:
  25.         return sequence
  26.     else:
  27.         return np.pad(sequence, ((0, timesteps - (sequence.shape[0] % timesteps)), (0, 0)), 'constant')
  28.  
  29.  
  30. data = np.loadtxt('data/mg17.csv', delimiter=',', dtype=DTYPE)
  31. X_data = data[:, [0]]
  32. Y_data = data[:, [1]]
  33. trX_data = X_data[:4000, :]
  34. trY_data = Y_data[:4000, :]
  35. vlX_data = X_data[4000:5000, :]
  36. vlY_data = Y_data[4000:5000, :]
  37.  
  38. trX = np.reshape(pad(trX_data, TIMESTEPS), (-1, TIMESTEPS, 1))
  39. trY = np.reshape(pad(trY_data, TIMESTEPS), (-1, TIMESTEPS, 1))
  40. vlX = np.reshape(pad(vlX_data, TIMESTEPS), (-1, TIMESTEPS, 1))
  41. vlY = np.reshape(pad(vlY_data, TIMESTEPS), (-1, TIMESTEPS, 1))
  42.  
  43. for r in range(5):
  44.     model = Sequential()
  45.     model.add(Masking(batch_input_shape=(1, TIMESTEPS, 1)))
  46.     model.add(SimpleRNN(10, stateful=True, return_sequences=True,
  47.                         batch_input_shape=(1, TIMESTEPS, 1)))
  48.     model.add(Dense(1))
  49.  
  50.     optimizer = keras.optimizers.Adam(lr=0.01)
  51.     model.compile(loss='mean_squared_error',
  52.                   optimizer=optimizer)
  53.  
  54.     history = model.fit(trX, trY, batch_size=1, validation_data=(vlX, vlY),
  55.                         shuffle=False, epochs=100, verbose=0,
  56.                          callbacks=[ResetStatesCallback(trX.shape[0])])
  57.  
  58.     plt.clf()
  59.     plt.plot(history.history['loss'])
  60.     plt.plot(history.history['val_loss'])
  61.     plt.xlabel("epoch")
  62.     plt.ylabel("loss")
  63.     plt.legend(["TR", "VL"])
  64.     plt.savefig("keras-mg-rnn-trunk("+str(r)+").png")
  65.     print("TR:", history.history['loss'][-1], "VL:", history.history['val_loss'][-1])
Advertisement
RAW Paste Data Copied
Advertisement