Guest User

Untitled

a guest
Oct 12th, 2017
114
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import keras.backend as K
  2. import multiprocessing
  3. import tensorflow as tf
  4. import numpy as np
  5. import h5py
  6.  
  7. from gensim.models.word2vec import Word2Vec
  8. from gensim.utils import simple_preprocess
  9. from gensim.parsing.preprocessing import STOPWORDS
  10. from gensim.parsing.porter import PorterStemmer
  11.  
  12. from keras.callbacks import EarlyStopping
  13. from keras.models import Sequential
  14. from keras.layers.core import Dense, Dropout, Flatten
  15. from keras.layers.convolutional import Conv1D
  16. from keras.optimizers import Adam
  17.  
  18. from sklearn.metrics import precision_recall_fscore_support, classification_report
  19. from sklearn.model_selection import KFold
  20.  
  21.  
  22. # Set random seed (for reproducibility)
  23. np.random.seed(1000)
  24.  
  25. use_gpu = True
  26.  
  27. config = tf.ConfigProto(intra_op_parallelism_threads=multiprocessing.cpu_count(),
  28. inter_op_parallelism_threads=multiprocessing.cpu_count(),
  29. allow_soft_placement=True,
  30. device_count={'CPU': 1,
  31. 'GPU': 1 if use_gpu else 0})
  32.  
  33. session = tf.Session(config=config)
  34. K.set_session(session)
  35.  
  36. dataset = '/tmp/test_dataset.txt'
  37.  
  38. corpus = []
  39. labels = []
  40.  
  41. # Parse texts and sentiments
  42. file_id = open(dataset, 'r')
  43. counter = 0
  44. for line in file_id:
  45. # Sentiment (0 = Negative, 1 = Positive)
  46. counter += 1
  47. labels.append(int(line[0]))
  48. # Text
  49. text = line[2:]
  50. if text.startswith('"'):
  51. text = text[1:]
  52. if text.endswith('"'):
  53. text = text[::-1]
  54. corpus.append(text.lower())
  55.  
  56. corpus = corpus[:1000000]
  57. labels = labels[:1000000]
  58. corpus_size = len(corpus)
  59. print('Corpus size: {}'.format(len(corpus)))
  60. print('1st text {}'.format(corpus[1]))
  61. print('label {}'.format(labels[1]))
  62. print('2nd text {}'.format(corpus[2]))
  63. print('label {}'.format(labels[2]))
  64.  
  65. # Tokenize and stem
  66. stemmer = PorterStemmer()
  67. tokenized_corpus = []
  68.  
  69. for i, text in enumerate(corpus):
  70. tokens = [stemmer.stem(t) for t in simple_preprocess(text) if t not in STOPWORDS or not t.startswith('@') or not t.startswith('#') \
  71. or not t.startswith('<br') or not t.startswith('http')]
  72. tokenized_corpus.append(tokens)
  73.  
  74. # Gensim Word2Vec model
  75. size = 512
  76. window = 10
  77.  
  78. # Create Word2Vec
  79. word2vec = Word2Vec(sentences=tokenized_corpus,
  80. size=size,
  81. window=window,
  82. negative=20,
  83. iter=50,
  84. seed=1000,
  85. workers=multiprocessing.cpu_count())
  86.  
  87. X_vecs = word2vec.wv
  88. del word2vec
  89. del corpus
  90.  
  91. # Compute average and max text length
  92. avg_length = 0.0
  93. max_length = 0
  94.  
  95. for text in tokenized_corpus:
  96. if len(text) > max_length:
  97. max_length = len(text)
  98. avg_length += float(len(text))
  99.  
  100. print('Length tokinzed corpus: {}'.format(len(tokenized_corpus)))
  101. print('Average text length: {}'.format(avg_length / float(len(tokenized_corpus))))
  102. print('Max text length: {}'.format(max_length))
  103.  
  104. # Text max length (number of tokens)
  105. max_text_length = 15
  106.  
  107. # Generate random indexes
  108. indexes = set(np.random.choice(len(tokenized_corpus), corpus_size, replace=False))
  109.  
  110. X = np.zeros((corpus_size, max_text_length, size), dtype=K.floatx())
  111. Y = np.zeros((corpus_size, 2), dtype=np.int32)
  112.  
  113. for i, index in enumerate(indexes):
  114. for t, token in enumerate(tokenized_corpus[index]):
  115. if t >= max_text_length:
  116. break
  117.  
  118. if token not in X_vecs:
  119. continue
  120.  
  121. X[i, t, :] = X_vecs[token]
  122. Y[i, :] = [1.0, 0.0] if labels[index] == 0 else [0.0, 1.0]
  123.  
  124. def create_model():
  125. model = Sequential()
  126.  
  127. model.add(Conv1D(32, kernel_size=3, activation='elu', padding='same', input_shape=(max_text_length, size)))
  128. model.add(Conv1D(32, kernel_size=3, activation='elu', padding='same'))
  129. model.add(Conv1D(32, kernel_size=3, activation='elu', padding='same'))
  130. model.add(Conv1D(32, kernel_size=3, activation='elu', padding='same'))
  131. model.add(Dropout(0.25))
  132.  
  133. model.add(Conv1D(32, kernel_size=2, activation='elu', padding='same'))
  134. model.add(Conv1D(32, kernel_size=2, activation='elu', padding='same'))
  135. model.add(Conv1D(32, kernel_size=2, activation='elu', padding='same'))
  136. model.add(Conv1D(32, kernel_size=2, activation='elu', padding='same'))
  137. model.add(Dropout(0.25))
  138.  
  139. model.add(Flatten())
  140.  
  141. model.add(Dense(256, activation='elu'))
  142. model.add(Dense(256, activation='elu'))
  143. model.add(Dropout(0.5))
  144.  
  145. model.add(Dense(2, activation='softmax'))
  146.  
  147. # Compile the model
  148. model.compile(loss='categorical_crossentropy',
  149. optimizer=Adam(lr=0.0001, decay=1e-6),
  150. metrics=['accuracy'])
  151. print(model.summary())
  152. return model
  153.  
  154. print('loading model')
  155. model = create_model()
  156. # load model weights
  157. model.load_weights('/tmp/Model_weights_sentimet_without_sw_and_with_stem.h5')
  158.  
  159. scores = model.evaluate(X, Y, verbose=0)
  160. print("%s: %.2f%%" % (model.metrics_names[1], scores[1] * 100))
  161. Y_pred = model.predict(X)
  162.  
  163. Y_label = []
  164. Y_pred_label = []
  165. for label in Y:
  166. if label[0] > label[1]:
  167. Y_label.append(0)
  168. else:
  169. Y_label.append(1)
  170.  
  171. for label in Y_pred:
  172. if label[0] > label[1]:
  173. Y_pred_label.append(0)
  174. else:
  175. Y_pred_label.append(1)
  176.  
  177. precision, recall, fscore, support = precision_recall_fscore_support(Y_label, Y_pred_label, labels=['0','1'])
  178. print(classification_report(Y_label, Y_pred_label))
  179.  
  180. print('precision', precision)
  181. print('recall', recall)
  182. print('fscore', fscore)
  183. print('support', support)
RAW Paste Data