Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import tensorflow as tf
- from tensorflow.keras.layers import Dense, LSTM, Masking, TimeDistributed, RepeatVector, Lambda, BatchNormalization
- from tensorflow.keras.preprocessing.sequence import pad_sequences
- from tensorflow.keras.models import Model
- from tensorflow.keras.models import Sequential
- import numpy as np
def repeat_vector(args):
    """Repeat a fixed-size encoding once per timestep of a reference sequence.

    ``args`` is a pair ``(layer_to_repeat, sequence_layer)``.  The encoding
    is tiled along a new time axis whose length is the *dynamic* timestep
    count of ``sequence_layer``, so variable-length batches work.
    """
    encoding, reference_sequence = args
    timesteps = tf.keras.backend.shape(reference_sequence)[1]
    return RepeatVector(timesteps)(encoding)
# Toy data: (samples, timesteps, features) with samples=4, features=3 and a
# variable number of timesteps per sample.
# NOTE: keep these as plain Python lists — wrapping ragged nested lists in
# np.array() raises ValueError on NumPy >= 1.24, and pad_sequences accepts
# lists of lists directly.
train_X = [
    [[0, 1, 2], [9, 8, 7], [3, 6, 8]],
    [[3, 4, 5]],
    [[6, 7, 8], [6, 5, 4], [1, 7, 4]],
    [[9, 0, 1], [3, 7, 4]],
]
test_X = [
    [[0, 1, 4]],
    [[6, 7, 4], [7, 3, 0], [5, 8, 9], [0, 2, 4]],
]
train_Y = np.array([0, 1, 1, 0])
n_feat = 3

# Zero-pad every sample to its batch's max length (post-padding so the
# Masking layer can skip the trailing all-zero timesteps).
# NOTE(review): mask_value=0 also masks any *real* timestep whose features
# are all zero — none occur in this toy data, but verify for real inputs.
train_X = pad_sequences(train_X, padding='post')
test_X = pad_sequences(test_X, padding='post')

# Masking handles the variable-length sequences after padding.
inputs = tf.keras.Input(shape=(None, n_feat))
masked_input = Masking(mask_value=0)(inputs)

# Encoder: compress each sequence into a single 16-dim vector.
encoder = LSTM(32, activation='tanh', return_sequences=True)(masked_input)
encoder = LSTM(units=16, activation='tanh', return_sequences=False)(encoder)

# Decoder: repeat the encoding once per input timestep, then reconstruct
# the n_feat features at every step.  output_shape is (None, 16): the
# repeated tensor carries the encoder's 16 units, not the raw feature count.
decoder = Lambda(repeat_vector, output_shape=(None, 16),
                 name='repeat_encoding')([encoder, masked_input])
decoder = LSTM(16, activation='tanh', return_sequences=True)(decoder)
decoder = LSTM(units=32, activation='tanh', return_sequences=True)(decoder)
decoder = TimeDistributed(Dense(units=n_feat))(decoder)

model = Model(inputs=inputs, outputs=decoder)
model.compile(optimizer='rmsprop', loss='mse')

# Inspect the repeated-encoding tensor by layer *name* instead of the
# brittle positional index model.layers[6].
mask_vector_layer = Model(inputs=model.inputs,
                          outputs=model.get_layer('repeat_encoding').output)
print(mask_vector_layer.summary())
mask_vector_layer_output = mask_vector_layer.predict(test_X)
print(mask_vector_layer_output)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement