from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import datetime
import pickle
from collections import Counter
import random
import numpy as np
import tensorflow as tf


def createTrainingMatrices(conversationFileName, wList, maxLen):
    # allow_pickle is required on newer NumPy versions to load a saved dict
    conversationDictionary = np.load(conversationFileName, allow_pickle=True).item()
    numExamples = len(conversationDictionary)
    xTrain = np.zeros((numExamples, maxLen), dtype='int32')
    yTrain = np.zeros((numExamples, maxLen), dtype='int32')
    for index, (key, value) in enumerate(conversationDictionary.items()):
        # Will store integerized representation of strings here (initialized as padding)
        encoderMessage = np.full((maxLen), wList.index('<pad>'), dtype='int32')
        decoderMessage = np.full((maxLen), wList.index('<pad>'), dtype='int32')
        # Getting all the individual words in the strings
        # (the key is a space-joined string, the value is already a list of words)
        keySplit = key.split()
        valueSplit = value
        keyCount = len(keySplit)
        valueCount = len(valueSplit)
        # Throw out sequences that are too long or are empty
        t1 = keyCount > (maxLen - 1)
        t2 = valueCount > (maxLen - 1)
        t3 = valueCount == 0
        t4 = keyCount == 0
        if t1 or t2 or t3 or t4:
            continue
        # Integerize the encoder string
        for keyIndex, word in enumerate(keySplit):
            if keyIndex > (maxLen - 1):
                continue
            try:
                encoderMessage[keyIndex] = wList.index(word)
            except ValueError:
                # 'WRONG' doubles as the out-of-vocabulary token (assumes it occurs in the corpus)
                encoderMessage[keyIndex] = wList.index('WRONG')
        encoderMessage[maxLen - 1] = wList.index('<EOS>')
        # Integerize the decoder string
        for valueIndex, word in enumerate(valueSplit):
            if valueIndex > (maxLen - 1):
                continue
            try:
                decoderMessage[valueIndex] = wList.index(word)
            except ValueError:
                decoderMessage[valueIndex] = wList.index('WRONG')
        decoderMessage[maxLen - 1] = wList.index('<EOS>')
        xTrain[index] = encoderMessage
        yTrain[index] = decoderMessage
    # Remove rows with all zeros
    yTrain = yTrain[~np.all(yTrain == 0, axis=1)]
    xTrain = xTrain[~np.all(xTrain == 0, axis=1)]
    numExamples = xTrain.shape[0]
    return numExamples, xTrain, yTrain
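
# Illustrative sketch (not part of the original script): with maxLen = 5 and a
# word list like ['hello', 'world', '<pad>', '<EOS>'], a key of "hello world"
# would be integerized into the encoder row [0, 1, 2, 2, 3] -- word ids first,
# '<pad>' filling the rest, and '<EOS>' forced into the final slot -- and the
# value list is encoded the same way into the matching decoder row.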


def process_dataset(filename):
    # Read the whole corpus into one string and count how often each word occurs
    with open(filename, 'r') as openedFile:
        giantTrmp = openedFile.read()
    occurenceDict = Counter(giantTrmp.split())
    return giantTrmp, occurenceDict


def getTestInput(inputMessage, wList, maxLen):
    encoderMessage = np.full((maxLen), wList.index('<pad>'), dtype='int32')
    inputSplit = inputMessage.lower().split()
    for index, word in enumerate(inputSplit):
        try:
            encoderMessage[index] = wList.index(word)
        except ValueError:
            continue
    # Place the <EOS> token right after the last input word, then reverse
    encoderMessage[index + 1] = wList.index('<EOS>')
    encoderMessage = encoderMessage[::-1]
    encoderMessageList = []
    for num in encoderMessage:
        encoderMessageList.append([num])
    return encoderMessageList
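
# Hedged usage sketch (assumed toy vocabulary, not from the original paste):
# with maxLen = 5 and wList = ['hello', 'there', '<pad>', '<EOS>'], the call
# getTestInput("hello there", wList, 5) first builds [0, 1, 3, 2, 2]
# (word ids, then '<EOS>', then padding), reverses it to [2, 2, 3, 1, 0], and
# returns [[2], [2], [3], [1], [0]] -- one single-element column per time
# step, i.e. a batch of size one in the time-major layout the placeholders use.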


def getTrainingBatch(localXTrain, localYTrain, localBatchSize, maxLen):
    # Note: relies on the module-level numTrainingExamples and wordList globals
    num = random.randint(0, numTrainingExamples - localBatchSize - 1)
    arr = localXTrain[num:num + localBatchSize]
    labels = localYTrain[num:num + localBatchSize]
    # Reversing the order of the encoder string apparently helps, as per the
    # 2014 Sutskever et al. sequence-to-sequence paper
    reversedList = list(arr)
    for index, example in enumerate(reversedList):
        reversedList[index] = list(reversed(example))

    # Lagged labels are for the training input into the decoder
    laggedLabels = []
    EOStokenIndex = wordList.index('<EOS>')
    padTokenIndex = wordList.index('<pad>')
    for example in labels:
        eosFound = np.argwhere(example == EOStokenIndex)[0]
        shiftedExample = np.roll(example, 1)
        shiftedExample[0] = EOStokenIndex
        # The EOS token was already at the end, so no need for pad
        if eosFound != (maxLen - 1):
            shiftedExample[eosFound + 1] = padTokenIndex
        laggedLabels.append(shiftedExample)

    # Need to transpose these into time-major order
    reversedList = np.asarray(reversedList).T.tolist()
    labels = labels.T.tolist()
    laggedLabels = np.asarray(laggedLabels).T.tolist()
    return reversedList, labels, laggedLabels
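
# Worked example of the label lag above (toy numbers, not from the original
# paste): for a label row [5, 9, EOS, pad, pad] with maxLen = 5, np.roll gives
# [pad, 5, 9, EOS, pad]; the first slot is overwritten with EOS and, because
# the original EOS was not already in the last position, the stale EOS copy at
# index 3 is padded out, leaving the decoder input [EOS, 5, 9, pad, pad].
# The decoder is therefore fed EOS as a start token plus the target shifted
# right by one step (teacher forcing), while the unshifted row stays the label.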


def idsToSentence(ids, wList):
    EOStokenIndex = wList.index('<EOS>')
    padTokenIndex = wList.index('<pad>')
    myStr = ""
    listOfResponses = []
    for num in ids:
        if num[0] == EOStokenIndex or num[0] == padTokenIndex:
            listOfResponses.append(myStr)
            myStr = ""
        else:
            myStr = myStr + wList[num[0]] + " "
    if myStr:
        listOfResponses.append(myStr)
    listOfResponses = [i for i in listOfResponses if i]
    return listOfResponses
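
# Illustrative sketch (toy vocabulary, not from the original paste): the ids
# coming out of decoderPrediction arrive as one batch-sized array per time
# step, so each num is indexed with num[0]. With
# wList = ['make', 'america', 'great', '<pad>', '<EOS>'] and
# ids = [[0], [1], [2], [4], [3]], the function accumulates "make america
# great " and returns ['make america great '] once the '<EOS>' id is hit.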


fullCorpus, datasetDictionary = process_dataset('speeches.txt')
print('Finished parsing and cleaning dataset')
wordList = list(datasetDictionary.keys())
with open("wordList.txt", "wb") as fp:
    pickle.dump(wordList, fp)

# Hyperparameters
batchSize = 24
maxEncoderLength = 100
maxDecoderLength = maxEncoderLength
lstmUnits = 112
embeddingDim = lstmUnits
numLayersLSTM = 15
numIterations = 500000

with open("wordList.txt", "rb") as fp:
    wordList = pickle.load(fp)

vocabSize = len(wordList)

padVector = np.zeros((1, 100), dtype='int32')
EOSVector = np.ones((1, 100), dtype='int32')
wordList.append('<pad>')
wordList.append('<EOS>')
vocabSize = vocabSize + 2

combinedDictionary = {}


# Larger sentence association: pair each vocabulary word (plus a random window
# of nearby words) with the next ~100 vocabulary entries to form pseudo
# conversation pairs. dict views can't be sliced in Python 3, so work from a
# list copy; values are stored as plain lists so the dictionary pickles
# cleanly through np.save.
vocabWords = list(datasetDictionary.keys())
counter = 0
for wor in vocabWords:
    print(wor)
    if counter < len(datasetDictionary) - 100:
        combinedDictionary[" ".join(Counter(
            (wor + " ".join(vocabWords[counter:counter + random.randint(1, 99)])).split()).keys())] \
            = list(Counter(vocabWords[counter:counter + 100]).keys())
        counter = counter + 1
    else:
        combinedDictionary[" ".join(Counter(
            (wor + " ".join(vocabWords[counter:counter + random.randint(1, 99)])).split()).keys())] \
            = list(Counter(vocabWords[counter - 100:len(datasetDictionary)]).keys())
        counter = counter + 1

np.save('conversationDictionary.npy', combinedDictionary)

numTrainingExamples, xT, yT = createTrainingMatrices('conversationDictionary.npy', wordList, maxEncoderLength)
np.save('Seq2SeqXTrain.npy', xT)
np.save('Seq2SeqYTrain.npy', yT)
print('Finished creating training matrices')
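
# A note on shapes (added commentary, not in the original paste): the batches
# built above are time-major. If enc, lab, lag = getTrainingBatch(...), then
# len(enc) == maxEncoderLength and each enc[t] holds batchSize word ids, which
# is exactly the per-time-step placeholder layout that the legacy seq2seq
# graph below is fed through encoderInputs[t], decoderLabels[t] and
# decoderInputs[t].
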
tf.reset_default_graph()

with tf.device('/gpu:0'):
    # Create the placeholders
    encoderInputs = [tf.placeholder(tf.int32, shape=(None,)) for i in range(maxEncoderLength)]
    decoderLabels = [tf.placeholder(tf.int32, shape=(None,)) for i in range(maxDecoderLength)]
    decoderInputs = [tf.placeholder(tf.int32, shape=(None,)) for i in range(maxDecoderLength)]
    feedPrevious = tf.placeholder(tf.bool)

    encoderLSTM = tf.nn.rnn_cell.BasicLSTMCell(lstmUnits, state_is_tuple=True)
    decoderOutputs, decoderFinalState = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
        encoderInputs, decoderInputs, encoderLSTM, vocabSize, vocabSize,
        embeddingDim, feed_previous=feedPrevious)

    decoderPrediction = tf.argmax(decoderOutputs, 2)

    lossWeights = [tf.ones_like(l, dtype=tf.float32) for l in decoderLabels]
    loss = tf.contrib.legacy_seq2seq.sequence_loss(decoderOutputs, decoderLabels, lossWeights, vocabSize)
    optimizer = tf.train.AdamOptimizer(1e-4).minimize(loss)

sess = tf.Session()
saver = tf.train.Saver()

sess.run(tf.global_variables_initializer())
tf.summary.scalar('Loss', loss)
merged = tf.summary.merge_all()
logdir = "tensorboard/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "/"
writer = tf.summary.FileWriter(logdir, sess.graph)

encoderTestStrings = ["This is america",
                      "Dont want ya slippin",
                      "What do you think about hillary",
                      "I think Immigrants are just the best dont you think?",
                      "bruh!"
                      ]

zeroVector = np.zeros((1), dtype='int32')

for i in range(numIterations):

    encoderTrain, decoderTargetTrain, decoderInputTrain = getTrainingBatch(xT, yT, batchSize, maxEncoderLength)
    feedDict = {encoderInputs[t]: encoderTrain[t] for t in range(maxEncoderLength)}
    feedDict.update({decoderLabels[t]: decoderTargetTrain[t] for t in range(maxDecoderLength)})
    feedDict.update({decoderInputs[t]: decoderInputTrain[t] for t in range(maxDecoderLength)})
    feedDict.update({feedPrevious: False})

    curLoss, _, pred = sess.run([loss, optimizer, decoderPrediction], feed_dict=feedDict)

    if i % 50 == 0:
        print('Current loss:', curLoss, 'at iteration', i)
        summary = sess.run(merged, feed_dict=feedDict)
        writer.add_summary(summary, i)
    if i % 25 == 0 and i != 0:
        num = random.randint(0, len(encoderTestStrings) - 1)
        print(encoderTestStrings[num])
        inputVector = getTestInput(encoderTestStrings[num], wordList, maxEncoderLength)
        feedDict = {encoderInputs[t]: inputVector[t] for t in range(maxEncoderLength)}
        feedDict.update({decoderLabels[t]: zeroVector for t in range(maxDecoderLength)})
        feedDict.update({decoderInputs[t]: zeroVector for t in range(maxDecoderLength)})
        feedDict.update({feedPrevious: True})
        ids = (sess.run(decoderPrediction, feed_dict=feedDict))
        print(Counter(idsToSentence(ids, wordList)).keys())

    if i % 10000 == 0 and i != 0:
        savePath = saver.save(sess, "models/pretrained_seq2seq.ckpt", global_step=i)
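
# A minimal inference sketch (assumed usage, not part of the original paste):
# after training, the latest checkpoint could be restored into the same graph
# and queried with feed_previous=True, e.g.
#
#   saver.restore(sess, tf.train.latest_checkpoint('models'))
#   inputVector = getTestInput("bruh!", wordList, maxEncoderLength)
#   feedDict = {encoderInputs[t]: inputVector[t] for t in range(maxEncoderLength)}
#   feedDict.update({decoderLabels[t]: zeroVector for t in range(maxDecoderLength)})
#   feedDict.update({decoderInputs[t]: zeroVector for t in range(maxDecoderLength)})
#   feedDict.update({feedPrevious: True})
#   print(idsToSentence(sess.run(decoderPrediction, feed_dict=feedDict), wordList))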