Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import keras
- import numpy as np
- import matplotlib.pyplot as plt
- from keras.models import Sequential
- from keras.layers import Dense, SimpleRNN, Masking
- from keras.callbacks import Callback
# Module-wide settings:
#   DTYPE     - NumPy dtype the CSV is parsed into by np.loadtxt.
#   TIMESTEPS - length of each window the series is reshaped into before
#               being fed to the recurrent network.
DTYPE, TIMESTEPS = np.float64, 50
class ResetStatesCallback(Callback):
    """Reset the model's recurrent state on a fixed batch schedule.

    States are reset at the start of every epoch and again after every
    ``n_batches`` consecutive training batches, so state is carried across
    at most ``n_batches`` batches.

    Args:
        n_batches: number of training batches between state resets; must be
            a positive integer.

    Raises:
        ValueError: if ``n_batches`` is not positive (it would otherwise
            surface later as a ZeroDivisionError inside training).
    """

    def __init__(self, n_batches):
        super(ResetStatesCallback, self).__init__()
        if n_batches <= 0:
            raise ValueError(
                "n_batches must be a positive integer, got {!r}".format(n_batches))
        self.n_batches = n_batches

    def on_epoch_begin(self, epoch, logs=None):
        # Every epoch starts from a clean recurrent state.
        self.model.reset_states()

    def on_batch_end(self, batch, logs=None):
        # `batch` is the batch index within the current epoch (0-based per the
        # Keras Callback docs), so `batch + 1` counts batches seen so far.
        if (batch + 1) % self.n_batches == 0:
            self.model.reset_states()
def pad(sequence, timesteps, value=0):
    """Pad ``sequence`` along its first axis up to a multiple of ``timesteps``.

    Only axis 0 is extended; any trailing axes are left untouched, so 1-D,
    2-D and higher-rank arrays are all supported.

    Args:
        sequence: NumPy array whose first axis is the one to be padded.
        timesteps: positive window length that the padded first-axis length
            must be divisible by.
        value: constant written into the appended rows (default 0).

    Returns:
        ``sequence`` itself (not a copy) when its first-axis length is already
        a multiple of ``timesteps``; otherwise a new array with the missing
        rows appended at the end.

    Raises:
        ValueError: if ``timesteps`` is not positive.
    """
    if timesteps <= 0:
        raise ValueError(
            "timesteps must be a positive integer, got {!r}".format(timesteps))
    # Number of rows needed to reach the next multiple of `timesteps` (0 if
    # already aligned): (-n) mod t.
    missing = -sequence.shape[0] % timesteps
    if missing == 0:
        return sequence
    pad_width = [(0, missing)] + [(0, 0)] * (sequence.ndim - 1)
    return np.pad(sequence, pad_width, mode='constant', constant_values=value)
# Load the series; column 0 is used as the model input and column 1 as the
# target (NOTE(review): meaning of the columns in data/mg17.csv not visible
# here — confirm).
data = np.loadtxt('data/mg17.csv', delimiter=',', dtype=DTYPE)
X_data = data[:, [0]]
Y_data = data[:, [1]]

# Chronological split: rows [0, TRAIN_END) train, [TRAIN_END, VALID_END) validate.
TRAIN_END = 4000
VALID_END = 5000
# The stateful RNN is built with a fixed batch dimension; this single constant
# keeps the layer's batch shape, fit()'s batch_size and the reset period in sync.
BATCH_SIZE = 1
N_RUNS = 5
EPOCHS = 100

trX_data = X_data[:TRAIN_END, :]
trY_data = Y_data[:TRAIN_END, :]
vlX_data = X_data[TRAIN_END:VALID_END, :]
vlY_data = Y_data[TRAIN_END:VALID_END, :]

# Cut each split into consecutive, non-overlapping windows of TIMESTEPS steps,
# shaped (n_windows, TIMESTEPS, 1).
trX = np.reshape(pad(trX_data, TIMESTEPS), (-1, TIMESTEPS, 1))
trY = np.reshape(pad(trY_data, TIMESTEPS), (-1, TIMESTEPS, 1))
vlX = np.reshape(pad(vlX_data, TIMESTEPS), (-1, TIMESTEPS, 1))
vlY = np.reshape(pad(vlY_data, TIMESTEPS), (-1, TIMESTEPS, 1))

for r in range(N_RUNS):
    # Discard graphs/state from the previous run so memory does not keep
    # growing while models are rebuilt in this loop.
    keras.backend.clear_session()

    model = Sequential()
    # Masking uses Keras' default mask_value (0.0), i.e. the value pad() fills
    # with; NOTE(review): genuine all-zero time steps would be masked too.
    model.add(Masking(batch_input_shape=(BATCH_SIZE, TIMESTEPS, 1)))
    # Only the first layer's input shape is used by Sequential, so no shape
    # argument is repeated here.
    model.add(SimpleRNN(10, stateful=True, return_sequences=True))
    model.add(Dense(1))
    # `learning_rate` is the documented argument; `lr` is a deprecated alias
    # that newer Keras versions reject.
    optimizer = keras.optimizers.Adam(learning_rate=0.01)
    model.compile(loss='mean_squared_error',
                  optimizer=optimizer)
    # shuffle=False keeps windows in chronological order so the carried-over
    # state is meaningful; states are reset after one full pass over the
    # training windows (trX.shape[0] // BATCH_SIZE batches).
    history = model.fit(trX, trY, batch_size=BATCH_SIZE,
                        validation_data=(vlX, vlY),
                        shuffle=False, epochs=EPOCHS, verbose=0,
                        callbacks=[ResetStatesCallback(trX.shape[0] // BATCH_SIZE)])

    plt.clf()
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.legend(["TR", "VL"])
    plt.savefig("keras-mg-rnn-trunk({}).png".format(r))
    print("TR:", history.history['loss'][-1], "VL:", history.history['val_loss'][-1])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement