Guest User


a guest
Jul 11th, 2016
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 12.04 KB | None | 0 0
  1. '''
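Recurrent network example. Trains a two-layer LSTM network to learn
text from a user-provided input file. The network can then be used to generate
text using a short string as seed (refer to the variable generation_phrase).
NOTE: in this adapted version the text-loading code is commented out; the
network is instead trained as a classifier on synthetic one-hot sequences
produced by generate_batch, each labelled with one of num_categories classes.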
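This example is partly based on Andrej Karpathy's blog
(http://karpathy.github.io/2015/05/21/rnn-effectiveness/)
and a similar example in the Keras package
(https://github.com/fchollet/keras/blob/master/examples/lstm_text_generation.py).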
  8. The inputs to the network are batches of sequences of characters and the corresponding
  9. targets are the characters in the text shifted to the right by one.
  10. Assuming a sequence length of 5, a training point for a text file
  11. "The quick brown fox jumps over the lazy dog" would be
  12. INPUT : 'T','h','e',' ','q'
  13. OUTPUT: 'u'
  15. The loss function compares (via categorical crossentropy) the prediction
  16. with the output/target.
  18. Also included is a function to generate text using the RNN given the first
  19. character.
  21. About 20 or so epochs are necessary to generate text that "makes sense".
  23. Written by @keskarnitish
  24. Pre-processing of text uses snippets of Karpathy's code (BSD License)
  25. '''
  27. from __future__ import print_function
  29. import numpy as np
  30. import theano
  31. import theano.tensor as T
  32. import lasagne
  33. import random
  34. import urllib2 # For downloading the sample text file. You won't need this if you are providing your own file.
  36. #import generate_data
  38. # try:
  39. # in_text = urllib2.urlopen('').read()
  40. # # You can also use your own file
  41. # # The file must be a simple text file.
  42. # # Simply edit the file name below and uncomment the line.
  43. # # in_text = open('your_file.txt', 'r').read()
  44. # in_text = in_text.decode("utf-8-sig").encode("utf-8")
  45. # except Exception as e:
  46. # print("Please verify the location of the input file/URL.")
  47. # print("A sample txt file can be downloaded from")
  48. # raise IOError('Unable to Read Text')
  50. # This snippet loads the text file and creates dictionaries to
  51. # encode characters into a vector-space representation and vice-versa.
  52. # chars = list(set(in_text))
  53. # data_size, vocab_size = len(in_text), len(chars)
  54. # char_to_ix = {ch: i for i, ch in enumerate(chars)}
  55. # ix_to_char = {i: ch for i, ch in enumerate(chars)}
  57. data_size = 10000
  59. # Lasagne Seed for Reproducibility
  60. lasagne.random.set_rng(np.random.RandomState(1))
  62. # Sequence Length
  63. SEQ_LENGTH = 20
  65. # Vocab size for proteins
  66. vocab_size = 20
  68. # Categories for proteins
  69. num_categories = 3
  71. # Number of units in the two hidden (LSTM) layers
  72. N_HIDDEN = 512
  74. # Optimization learning rate
  75. LEARNING_RATE = .01
  77. # All gradients above this will be clipped
  78. GRAD_CLIP = 100
  80. # How often should we check the output?
  81. PRINT_FREQ = 10
  83. # Number of epochs to train the net
  84. NUM_EPOCHS = 50
  86. # Batch Size
  87. BATCH_SIZE = 128
  90. # def gen_data(p, batch_size=BATCH_SIZE, data=in_text, return_target=True):
  91. # '''
  92. # This function produces a semi-redundant batch of training samples from the location 'p' in the provided string (data).
  93. # For instance, assuming SEQ_LENGTH = 5 and p=0, the function would create batches of
  94. # 5 characters of the string (starting from the 0th character and stepping by 1 for each semi-redundant batch)
  95. # as the input and the next character as the target.
  96. # To make this clear, let us look at a concrete example. Assume that SEQ_LENGTH = 5, p = 0 and BATCH_SIZE = 2
  97. # If the input string was "The quick brown fox jumps over the lazy dog.",
  98. # For the first data point,
  99. # x (the inputs to the neural network) would correspond to the encoding of 'T','h','e',' ','q'
  100. # y (the targets of the neural network) would be the encoding of 'u'
  101. # For the second point,
  102. # x (the inputs to the neural network) would correspond to the encoding of 'h','e',' ','q', 'u'
  103. # y (the targets of the neural network) would be the encoding of 'i'
  104. # The data points are then stacked (into a three-dimensional tensor of size (batch_size,SEQ_LENGTH,vocab_size))
  105. # and returned.
  106. # Notice that there is overlap of characters between the batches (hence the name, semi-redundant batch).
  107. # '''
  108. # x = np.zeros((batch_size, SEQ_LENGTH, vocab_size))
  109. # y = np.zeros(batch_size)
  110. #
  111. # for n in range(batch_size):
  112. # ptr = n
  113. # for i in range(SEQ_LENGTH):
  114. # x[n, i, char_to_ix[data[p + ptr + i]]] = 1.
  115. # if (return_target):
  116. # y[n] = char_to_ix[data[p + ptr + SEQ_LENGTH]]
  117. # return x, np.array(y, dtype='int32')
  120. def generate_batch(batch_size=50, seq_length=20):
  121. X = np.zeros((batch_size, seq_length, vocab_size), dtype='int32')
  122. y = np.zeros((batch_size, seq_length, num_categories), dtype='int32')
  123. z = np.zeros((batch_size, seq_length), dtype='int32')
  124. a = np.zeros(batch_size, dtype='int32')
  126. for i in range(batch_size):
  127. for j in range(seq_length):
  128. Xval = random.randint(0, 19)
  130. ## This is the very simple function that we hope the LSTM model will learn:
  131. ## i.e. characters in the interval [0, 6] result in category 0
  132. ## characters in the interval [7, 12] result in category 1
  133. ## characters in the interval [13, 19] result in category 2
  134. if Xval <= 6:
  135. yval = 0
  136. elif 5 < Xval <= 12:
  137. yval = 1
  138. else:
  139. yval = 2
  141. X[i][j][Xval] = 1
  142. y[i][yval] = 1
  143. a[i] = yval
  144. return X, a
  146. X_seed, y_seed = generate_batch(batch_size=1)
  147. #generation_phrase = "The quick brown fox jumps" # This phrase will be used as seed to generate text.
  148. generation_phrase = X_seed[0]
  150. def main(num_epochs=NUM_EPOCHS):
  151. print("Building network ...")
  153. # First, we build the network, starting with an input layer
  154. # Recurrent layers expect input of shape
  155. # (batch size, SEQ_LENGTH, num_features)
  157. l_in = lasagne.layers.InputLayer(shape=(None, None, vocab_size))
  159. # We now build the LSTM layer which takes l_in as the input layer
  160. # We clip the gradients at GRAD_CLIP to prevent the problem of exploding gradients.
  162. l_forward_1 = lasagne.layers.LSTMLayer(
  163. l_in, N_HIDDEN, grad_clipping=GRAD_CLIP,
  164. nonlinearity=lasagne.nonlinearities.tanh)
  166. l_forward_2 = lasagne.layers.LSTMLayer(
  167. l_forward_1, N_HIDDEN, grad_clipping=GRAD_CLIP,
  168. nonlinearity=lasagne.nonlinearities.tanh,
  169. only_return_final=True)
  171. # The output of l_forward_2 of shape (batch_size, N_HIDDEN) is then passed through the softmax nonlinearity to
  172. # create probability distribution of the prediction
  173. # The output of this stage is (batch_size, num_categories)
  174. l_out = lasagne.layers.DenseLayer(l_forward_2, num_units=num_categories, W=lasagne.init.Normal(),
  175. nonlinearity=lasagne.nonlinearities.softmax)
  177. # Theano tensor for the targets
  178. target_values = T.ivector('target_output')
  179. #target_values = T.imatrix('target_ouput')
  181. # lasagne.layers.get_output produces a variable for the output of the net
  182. network_output = lasagne.layers.get_output(l_out)
  184. # The loss function is calculated as the mean of the (categorical) cross-entropy between the prediction and target.
  185. cost = T.nnet.categorical_crossentropy(network_output, target_values).mean()
  187. # Retrieve all parameters from the network
  188. all_params = lasagne.layers.get_all_params(l_out, trainable=True)
  190. # Compute AdaGrad updates for training
  191. print("Computing updates ...")
  192. updates = lasagne.updates.adagrad(cost, all_params, LEARNING_RATE)
  194. # Theano functions for training and computing cost
  195. print("Compiling functions ...")
  196. train = theano.function([l_in.input_var, target_values], cost, updates=updates, allow_input_downcast=True)
  197. compute_cost = theano.function([l_in.input_var, target_values], cost, allow_input_downcast=True)
  199. # In order to generate text from the network, we need the probability distribution of the next character given
  200. # the state of the network and the input (a seed).
  201. # In order to produce the probability distribution of the prediction, we compile a function called probs.
  203. probs = theano.function([l_in.input_var], network_output, allow_input_downcast=True)
  205. # The next function generates text given a phrase of length at least SEQ_LENGTH.
  206. # The phrase is set using the variable generation_phrase.
  207. # The optional input "N" is used to set the number of characters of text to predict.
  209. # def try_it_out(N=200):
  210. # '''
  211. # This function uses the user-provided string "generation_phrase" and current state of the RNN generate text.
  212. # The function works in three steps:
  213. # 1. It converts the string set in "generation_phrase" (which must be over SEQ_LENGTH characters long)
  214. # to encoded format. We use the gen_data function for this. By providing the string and asking for a single batch,
  215. # we are converting the first SEQ_LENGTH characters into encoded form.
  216. # 2. We then use the LSTM to predict the next character and store it in a (dynamic) list sample_ix. This is done by using the 'probs'
  217. # function which was compiled above. Simply put, given the output, we compute the probabilities of the target and pick the one
  218. # with the highest predicted probability.
  219. # 3. Once this character has been predicted, we construct a new sequence using all but first characters of the
  220. # provided string and the predicted character. This sequence is then used to generate yet another character.
  221. # This process continues for "N" characters.
  222. # To make this clear, let us again look at a concrete example.
  223. # Assume that SEQ_LENGTH = 5 and generation_phrase = "The quick brown fox jumps".
  224. # We initially encode the first 5 characters ('T','h','e',' ','q'). The next character is then predicted (as explained in step 2).
  225. # Assume that this character was 'J'. We then construct a new sequence using the last 4 (=SEQ_LENGTH-1) characters of the previous
  226. # sequence ('h','e',' ','q') , and the predicted letter 'J'. This new sequence is then used to compute the next character and
  227. # the process continues.
  228. # '''
  229. #
  230. # assert (len(generation_phrase) >= SEQ_LENGTH)
  231. # sample_ix = []
  232. #
  233. # x, _ = generate_batch(len(generation_phrase) - SEQ_LENGTH, 1, generation_phrase, 0)
  234. #
  235. # for i in range(N):
  236. # # Pick the character that got assigned the highest probability
  237. # ix = np.argmax(probs(x).ravel())
  238. # # Alternatively, to sample from the distribution instead:
  239. # # ix = np.random.choice(np.arange(vocab_size), p=probs(x).ravel())
  240. # sample_ix.append(ix)
  241. # x[:, 0:SEQ_LENGTH - 1, :] = x[:, 1:, :]
  242. # x[:, SEQ_LENGTH - 1, :] = 0
  243. # x[0, SEQ_LENGTH - 1, sample_ix[-1]] = 1.
  244. #
  245. # random_snippet = generation_phrase + ''.join(ix_to_char[ix] for ix in sample_ix)
  246. # print("----\n %s \n----" % random_snippet)
  248. print("Training ...")
  249. #print("Seed used for text generation is: " + generation_phrase)
  250. p = 0
  251. try:
  252. for it in xrange(data_size * num_epochs / BATCH_SIZE):
  253. #try_it_out() # Generate text using the p^th character as the start.
  255. avg_cost = 0;
  256. for i in range(PRINT_FREQ):
  257. x, y = generate_batch(BATCH_SIZE, SEQ_LENGTH)
  259. # print(p)
  260. p += SEQ_LENGTH + BATCH_SIZE - 1
  261. if (p + BATCH_SIZE + SEQ_LENGTH >= data_size):
  262. print('Carriage Return')
  263. p = 0;
  265. avg_cost += train(x, y)
  266. print("War were declared: " + str(i))
  267. print("Epoch {} average loss = {}".format(it * 1.0 * PRINT_FREQ / data_size * BATCH_SIZE,
  268. avg_cost / PRINT_FREQ))
  270. except KeyboardInterrupt:
  271. pass
  274. if __name__ == '__main__':
  275. main()
Add Comment
Please, Sign In to add comment