Guest User

Untitled

a guest
Apr 19th, 2018
192
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.80 KB | None | 0 0
  1. from __future__ import print_function
  2. from keras.preprocessing import sequence
  3. from keras.models import Sequential
  4. from keras.layers import Dense, Embedding, Dropout
  5. from keras.layers import LSTM, TimeDistributed, Flatten
  6. from keras.datasets import imdb
  7. from keras.callbacks import EarlyStopping, ModelCheckpoint
  8. import numpy as np
  9.  
  10.  
  11. max_features = 20000
  12. maxlen = 80 # cut texts after this number of words (among top max_features most common words)
  13. batch_size = 32
  14. embedding_dim = 100
  15.  
  16. def generate_batch(batchsize):
  17.  
  18. (x_train, y_train), (_,_) = imdb.load_data(num_words=max_features)
  19. print("train_size", x_train.shape)
  20. while True:
  21. for i in range(0, len(x_train), batchsize):
  22. x_batch = x_train[i:(i+batchsize)]
  23. y_batch = y_train[i:(i+batchsize)]
  24. x_batch = sequence.pad_sequences(x_batch, maxlen=maxlen, padding='post')
  25. yield(x_batch, y_batch)
  26.  
  27. def generate_val(valsize):
  28.  
  29. (_,_), (x_test, y_test) = imdb.load_data(num_words=max_features)
  30. print("test_size", x_test.shape)
  31. while True:
  32. for i in range(0, len(x_test), valsize):
  33. x_val = x_test[i:(i+valsize)]
  34. y_val = y_test[i:(i+valsize)]
  35. x_val = sequence.pad_sequences(x_val, maxlen=maxlen, padding='post')
  36. yield(x_val, y_val)
  37.  
  38. print('Build model...')
  39. primary_model = Sequential()
  40. primary_model.add(Embedding(input_dim = max_features,
  41. output_dim = embedding_dim,
  42. trainable=True,
  43. weights=[(np.eye(max_features,embedding_dim))],
  44. mask_zero=True))
  45. primary_model.add(TimeDistributed(Dense(150, use_bias=False)))
  46. primary_model.add(LSTM(128))
  47. primary_model.add(Dense(2, activation='softmax'))
  48. primary_model.summary()
  49. primary_model.compile(loss='sparse_categorical_crossentropy',
  50. optimizer='adam',
  51. metrics=['accuracy'])
  52.  
  53. print('Train...')
  54. filepath = "primeweights-{epoch:02d}-{val_acc:.2f}.hdf5"
  55. checkpoint = ModelCheckpoint(filepath,
  56. verbose=1,
  57. save_best_only=True)
  58. early_stopping_monitor = EarlyStopping(patience=2)
  59.  
  60. primary_model.fit_generator(generate_batch(25),
  61. steps_per_epoch = 1000,
  62. epochs = 2,
  63. callbacks=[early_stopping_monitor],
  64. validation_data=generate_val(25),
  65. validation_steps=1000)
  66.  
  67.  
  68. (_,_), (x_test, y_test) = imdb.load_data(num_words=max_features)
  69. x_test = sequence.pad_sequences(x_test, maxlen=maxlen, padding='post')
  70. score, acc = primary_model.evaluate(x_test, y_test, batch_size=batch_size)
  71. print('Test score:', score)
  72. print('Test accuracy:', acc)
  73.  
  74. primary_model.save('primary_model_imdb.h5')
Add Comment
Please, Sign In to add comment