from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import datetime
import pickle
from collections import Counter
import random
import numpy as np
import tensorflow as tf

def createTrainingMatrices(conversationFileName, wList, maxLen):
    conversationDictionary = np.load(conversationFileName, allow_pickle=True).item()
    numExamples = len(conversationDictionary)
    xTrain = np.zeros((numExamples, maxLen), dtype='int32')
    yTrain = np.zeros((numExamples, maxLen), dtype='int32')
    padIndex = wList.index('<pad>')
    eosIndex = wList.index('<EOS>')
    for index, (key, value) in enumerate(conversationDictionary.items()):
        # Will store integerized representation of strings here (initialized as padding)
        encoderMessage = np.full((maxLen), padIndex, dtype='int32')
        decoderMessage = np.full((maxLen), padIndex, dtype='int32')
        # Getting all the individual words (keys are space-joined strings,
        # values are lists of words; see the dictionary built below)
        keySplit = key.split()
        valueSplit = list(value)
        keyCount = len(keySplit)
        valueCount = len(valueSplit)
        # Throw out sequences that are too long or are empty
        if keyCount > (maxLen - 1) or valueCount > (maxLen - 1) or keyCount == 0 or valueCount == 0:
            continue
        # Integerize the encoder string
        for keyIndex, word in enumerate(keySplit):
            try:
                encoderMessage[keyIndex] = wList.index(word)
            except ValueError:
                # Out-of-vocabulary words stay as padding
                encoderMessage[keyIndex] = padIndex
        encoderMessage[maxLen - 1] = eosIndex
        # Integerize the decoder string
        for valueIndex, word in enumerate(valueSplit):
            try:
                decoderMessage[valueIndex] = wList.index(word)
            except ValueError:
                decoderMessage[valueIndex] = padIndex
        decoderMessage[maxLen - 1] = eosIndex
        xTrain[index] = encoderMessage
        yTrain[index] = decoderMessage
    # Remove rows that were never filled in (skipped examples stay all zeros)
    yTrain = yTrain[~np.all(yTrain == 0, axis=1)]
    xTrain = xTrain[~np.all(xTrain == 0, axis=1)]
    numExamples = xTrain.shape[0]
    return numExamples, xTrain, yTrain
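
# Usage sketch (the real call appears near the bottom of this script); both
# matrices come back as (numExamples, maxLen) int32 arrays:
#
#   n, xT, yT = createTrainingMatrices('conversationDictionary.npy', wordList, 100)
#   # xT.shape == (n, 100) and yT.shape == (n, 100)
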
def process_dataset(filename):
    # Read the whole corpus into one string and count how often each word occurs
    with open(filename, 'r') as openedFile:
        giantTrmp = openedFile.read()
    occurrenceDict = Counter(giantTrmp.split())
    return giantTrmp, occurrenceDict
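
# Usage sketch, assuming speeches.txt held just the text "make it great again great":
#
#   text, counts = process_dataset('speeches.txt')
#   # counts['great'] == 2 and counts['make'] == 1
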
def getTestInput(inputMessage, wList, maxLen):
    encoderMessage = np.full((maxLen), wList.index('<pad>'), dtype='int32')
    # Truncate so the <EOS> marker always fits inside the message
    inputSplit = inputMessage.lower().split()[:maxLen - 1]
    for index, word in enumerate(inputSplit):
        try:
            encoderMessage[index] = wList.index(word)
        except ValueError:
            # Out-of-vocabulary words are left as padding
            continue
    encoderMessage[len(inputSplit)] = wList.index('<EOS>')
    encoderMessage = encoderMessage[::-1]
    encoderMessageList = []
    for num in encoderMessage:
        encoderMessageList.append([num])
    return encoderMessageList
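
# The result is time-major and reversed: one single-element list per timestep.
# Hypothetical example with maxLen = 5 and the input "hello world":
#
#   getTestInput("hello world", wList, 5)
#   # -> [[pad], [pad], [EOS], [world], [hello]]  (each entry a vocabulary id)
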
def getTrainingBatch(localXTrain, localYTrain, localBatchSize, maxLen):
    # Note: relies on the globals numTrainingExamples and wordList defined below
    num = random.randint(0, numTrainingExamples - localBatchSize - 1)
    arr = localXTrain[num:num + localBatchSize]
    labels = localYTrain[num:num + localBatchSize]
    # Reversing the order of the encoder string helps, as per the 2014
    # Sutskever et al. sequence-to-sequence paper
    reversedList = list(arr)
    for index, example in enumerate(reversedList):
        reversedList[index] = list(reversed(example))
    # Lagged labels are the training inputs into the decoder
    laggedLabels = []
    EOStokenIndex = wordList.index('<EOS>')
    padTokenIndex = wordList.index('<pad>')
    for example in labels:
        eosFound = np.argwhere(example == EOStokenIndex)[0][0]
        shiftedExample = np.roll(example, 1)
        shiftedExample[0] = EOStokenIndex
        # The EOS token was already at the end, so no need for pad
        if eosFound != (maxLen - 1):
            shiftedExample[eosFound + 1] = padTokenIndex
        laggedLabels.append(shiftedExample)
    # Transpose from batch-major to time-major, as the placeholders expect
    reversedList = np.asarray(reversedList).T.tolist()
    labels = labels.T.tolist()
    laggedLabels = np.asarray(laggedLabels).T.tolist()
    return reversedList, labels, laggedLabels
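
# Shape sketch: each returned list is time-major to match the per-timestep
# placeholders built below, i.e. maxLen entries of localBatchSize ids apiece:
#
#   enc, lab, dec = getTrainingBatch(xT, yT, 24, 100)
#   # len(enc) == 100 and len(enc[0]) == 24
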
def idsToSentence(ids, wList):
    EOStokenIndex = wList.index('<EOS>')
    padTokenIndex = wList.index('<pad>')
    myStr = ""
    listOfResponses = []
    for num in ids:
        if num[0] == EOStokenIndex or num[0] == padTokenIndex:
            listOfResponses.append(myStr)
            myStr = ""
        else:
            myStr = myStr + wList[num[0]] + " "
    if myStr:
        listOfResponses.append(myStr)
    listOfResponses = [i for i in listOfResponses if i]
    return listOfResponses
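
# Hypothetical example: with wList = ['hi', 'there', '<pad>', '<EOS>'],
#
#   idsToSentence([[0], [1], [3]], wList)  # -> ['hi there ']  (note the trailing space)
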
fullCorpus, datasetDictionary = process_dataset('speeches.txt')
print('Finished parsing and cleaning dataset')
wordList = list(datasetDictionary.keys())
with open("wordList.txt", "wb") as fp:
    pickle.dump(wordList, fp)

# Hyperparameters
batchSize = 24
maxEncoderLength = 100
maxDecoderLength = maxEncoderLength
lstmUnits = 112
embeddingDim = lstmUnits
numLayersLSTM = 15  # defined but never used; the seq2seq helper below builds a single cell
numIterations = 500000

- with open("wordList.txt", "rb") as fp:
- wordList = pickle.load(fp)
- vocabSize = len(wordList)
- padVector = np.zeros((1, 100), dtype='int32')
- EOSVector = np.ones((1, 100), dtype='int32')
- wordList.append('<pad>')
- wordList.append('<EOS>')
- vocabSize = vocabSize + 2
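
# Added sanity check (assuming the corpus never contains the literal strings
# '<pad>' or '<EOS>'): the two special tokens occupy the last two vocabulary slots
assert wordList.index('<pad>') == vocabSize - 2
assert wordList.index('<EOS>') == vocabSize - 1
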
combinedDictionary = {}
# Larger sentence association: pair each word (plus a random window of the words
# that follow it) with the next 100 vocabulary words, forming prompt/response pairs
datasetWords = list(datasetDictionary.keys())
for counter, wor in enumerate(datasetWords):
    windowSize = random.randint(1, 99)
    # Note the " " separator; without it the first window word fuses onto wor,
    # creating an out-of-vocabulary token in the key
    keyWords = (wor + " " + " ".join(datasetWords[counter:counter + windowSize])).split()
    key = " ".join(Counter(keyWords).keys())
    if counter < len(datasetWords) - 100:
        valueWords = datasetWords[counter:counter + 100]
    else:
        valueWords = datasetWords[counter - 100:len(datasetWords)]
    combinedDictionary[key] = list(Counter(valueWords).keys())

np.save('conversationDictionary.npy', combinedDictionary)
numTrainingExamples, xT, yT = createTrainingMatrices('conversationDictionary.npy', wordList, maxEncoderLength)
np.save('Seq2SeqXTrain.npy', xT)
np.save('Seq2SeqYTrain.npy', yT)
print('Finished creating training matrices')

tf.reset_default_graph()
with tf.device('/gpu:0'):
    # Create the placeholders: one int32 vector per timestep, batch size left dynamic
    encoderInputs = [tf.placeholder(tf.int32, shape=(None,)) for i in range(maxEncoderLength)]
    decoderLabels = [tf.placeholder(tf.int32, shape=(None,)) for i in range(maxDecoderLength)]
    decoderInputs = [tf.placeholder(tf.int32, shape=(None,)) for i in range(maxDecoderLength)]
    feedPrevious = tf.placeholder(tf.bool)
    encoderLSTM = tf.nn.rnn_cell.BasicLSTMCell(lstmUnits, state_is_tuple=True)
    decoderOutputs, decoderFinalState = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
        encoderInputs, decoderInputs, encoderLSTM, vocabSize, vocabSize, embeddingDim,
        feed_previous=feedPrevious)
    decoderPrediction = tf.argmax(decoderOutputs, 2)
    lossWeights = [tf.ones_like(l, dtype=tf.float32) for l in decoderLabels]
    # Note: sequence_loss takes (logits, targets, weights); the original passed
    # vocabSize as a fourth positional argument, which landed on the truthy
    # average_across_timesteps flag and is dropped here
    loss = tf.contrib.legacy_seq2seq.sequence_loss(decoderOutputs, decoderLabels, lossWeights)
    optimizer = tf.train.AdamOptimizer(1e-4).minimize(loss)
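    # Optional addition (not in the original script): sequence_loss returns the
    # average per-word cross-entropy, so its exponential is the training perplexity
    perplexity = tf.exp(loss)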

sess = tf.Session()
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
tf.summary.scalar('Loss', loss)
merged = tf.summary.merge_all()
logdir = "tensorboard/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "/"
writer = tf.summary.FileWriter(logdir, sess.graph)

encoderTestStrings = ["This is america",
                      "Dont want ya slippin",
                      "What do you think about hillary",
                      "I think Immigrants are just the best dont you think?",
                      "bruh!"
                      ]
zeroVector = np.zeros((1), dtype='int32')

for i in range(numIterations):
    encoderTrain, decoderTargetTrain, decoderInputTrain = getTrainingBatch(xT, yT, batchSize, maxEncoderLength)
    feedDict = {encoderInputs[t]: encoderTrain[t] for t in range(maxEncoderLength)}
    feedDict.update({decoderLabels[t]: decoderTargetTrain[t] for t in range(maxDecoderLength)})
    feedDict.update({decoderInputs[t]: decoderInputTrain[t] for t in range(maxDecoderLength)})
    feedDict.update({feedPrevious: False})
    curLoss, _, pred = sess.run([loss, optimizer, decoderPrediction], feed_dict=feedDict)
    if i % 50 == 0:
        print('Current loss:', curLoss, 'at iteration', i)
        summary = sess.run(merged, feed_dict=feedDict)
        writer.add_summary(summary, i)
    if i % 25 == 0 and i != 0:
        # Decode one of the test strings with feed_previous=True (inference mode)
        num = random.randint(0, len(encoderTestStrings) - 1)
        print(encoderTestStrings[num])
        inputVector = getTestInput(encoderTestStrings[num], wordList, maxEncoderLength)
        feedDict = {encoderInputs[t]: inputVector[t] for t in range(maxEncoderLength)}
        feedDict.update({decoderLabels[t]: zeroVector for t in range(maxDecoderLength)})
        feedDict.update({decoderInputs[t]: zeroVector for t in range(maxDecoderLength)})
        feedDict.update({feedPrevious: True})
        ids = sess.run(decoderPrediction, feed_dict=feedDict)
        print(Counter(idsToSentence(ids, wordList)).keys())
    if i % 10000 == 0 and i != 0:
        savePath = saver.save(sess, "models/pretrained_seq2seq.ckpt", global_step=i)
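
# Final checkpoint once training finishes (a small addition; the in-loop saver
# above already writes periodic checkpoints to the same "models/" path)
savePath = saver.save(sess, "models/pretrained_seq2seq.ckpt", global_step=numIterations)
writer.close()
print('Saved final model to', savePath)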