Advertisement
Guest User

Text Generation LSTM

a guest
Aug 2nd, 2020
161
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.99 KB | None | 0 0
  1. # Small LSTM Network to Generate Text for Alice in Wonderland
  2. import numpy
  3. from tensorflow.keras.models import Sequential
  4. from tensorflow.keras.layers import Dense
  5. from tensorflow.keras.layers import Dropout
  6. from tensorflow.keras.layers import LSTM
  7. from tensorflow.keras.callbacks import ModelCheckpoint
  8. from tensorflow.keras.utils import to_categorical
  9. # load ascii text and covert to lowercase
  10. filename = "wonderland.txt"
  11. raw_text = open(filename, 'r', encoding='utf-8').read()
  12. raw_text = raw_text.lower()
  13. # create mapping of unique chars to integers
  14. chars = sorted(list(set(raw_text)))
  15. char_to_int = dict((c, i) for i, c in enumerate(chars))
  16. # summarize the loaded data
  17. n_chars = len(raw_text)
  18. n_vocab = len(chars)
  19. print "Total Characters: ", n_chars
  20. print "Total Vocab: ", n_vocab
  21. # prepare the dataset of input to output pairs encoded as integers
  22. seq_length = 100
  23. prediction_length = 2
  24. dataX = []
  25. dataY = []
  26. encoded_text = [char_to_int[char] for char in raw_text]
  27. for i in range(0, n_chars - seq_length, 1):
  28.     seq_in = encoded_text[i:i + seq_length]
  29.     seq_out = encoded_text[i + seq_length: i + seq_length + prediction_length]
  30.     dataX.append(seq_in)
  31.     dataY.append(seq_out)
  32. n_patterns = len(dataX)
  33. print "Total Patterns: ", n_patterns
  34. # reshape X to be [samples, time steps, features]
  35. X = numpy.reshape(dataX, (-1, seq_length, prediction_length))
  36. # normalize
  37. X = X / float(n_vocab)
  38. # one hot encode the output variable
  39. y = to_categorical(dataY)
  40. # define the LSTM model
  41. model = Sequential()
  42. model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))
  43. model.add(Dropout(0.2))
  44. model.add(Dense(y.shape[1], activation='softmax'))
  45. model.compile(loss='categorical_crossentropy', optimizer='adam')
  46. # define the checkpoint
  47. filepath="weights-improvement-{epoch:02d}-{loss:.4f}.hdf5"
  48. checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
  49. callbacks_list = [checkpoint]
  50. # fit the model
  51. model.fit(X, y, epochs=20, batch_size=128, callbacks=callbacks_list)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement