Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Small LSTM Network to Generate Text for Alice in Wonderland
#
# Trains a character-level model: each sample is a window of `seq_length`
# characters, and the target is the `prediction_length` characters that
# immediately follow that window.
import numpy
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import Reshape
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.utils import to_categorical

# load text and convert to lowercase
filename = "wonderland.txt"
# the context manager closes the file handle as soon as the text is read
with open(filename, 'r', encoding='utf-8') as text_file:
    raw_text = text_file.read()
raw_text = raw_text.lower()

# create mapping of unique chars to integers
chars = sorted(set(raw_text))
char_to_int = {c: i for i, c in enumerate(chars)}

# summarize the loaded data
n_chars = len(raw_text)
n_vocab = len(chars)
print("Total Characters: ", n_chars)
print("Total Vocab: ", n_vocab)

# prepare the dataset of input to output pairs encoded as integers
seq_length = 100
prediction_length = 2
encoded_text = [char_to_int[char] for char in raw_text]

# Only keep windows that are followed by a FULL target of `prediction_length`
# characters; stopping at `n_chars - seq_length` would leave truncated targets
# at the end of the text and make dataY ragged.
n_windows = n_chars - seq_length - prediction_length + 1
if n_windows <= 0:
    raise ValueError(
        "%s is too short: need at least %d characters, got %d"
        % (filename, seq_length + prediction_length, n_chars)
    )

dataX = []
dataY = []
for i in range(n_windows):
    target_start = i + seq_length
    dataX.append(encoded_text[i:target_start])
    dataY.append(encoded_text[target_start:target_start + prediction_length])
n_patterns = len(dataX)
print("Total Patterns: ", n_patterns)

# reshape X to be [samples, time steps, features]; each time step carries a
# single feature (the integer code of one character)
X = numpy.reshape(dataX, (n_patterns, seq_length, 1))
# normalize
X = X / float(n_vocab)

# one hot encode the output variable as [samples, prediction steps, vocab].
# num_classes is pinned to the vocabulary size so the width does not depend on
# which characters happen to appear in the targets; the explicit reshape keeps
# the prediction-step axis even when prediction_length is 1.
y = to_categorical(dataY, num_classes=n_vocab)
y = y.reshape((n_patterns, prediction_length, n_vocab))

# define the LSTM model
model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))
model.add(Dropout(0.2))
# one score per (prediction step, character), regrouped so that softmax
# normalizes over the vocabulary separately for every predicted character
model.add(Dense(prediction_length * n_vocab))
model.add(Reshape((prediction_length, n_vocab)))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')

# define the checkpoint
# NOTE(review): newer Keras releases restrict checkpoint file extensions --
# confirm that '.hdf5' is accepted by the installed version.
filepath = "weights-improvement-{epoch:02d}-{loss:.4f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
callbacks_list = [checkpoint]

# fit the model
model.fit(X, y, epochs=20, batch_size=128, callbacks=callbacks_list)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement