Guest User

Untitled

a guest
Apr 25th, 2018
95
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.83 KB | None | 0 0
  1. #!/usr/bin/env python
  2.  
  3. """ This example trains an LSTM to predict the next number
  4. in a sequence 0, 0, 1, 1, 0, 0, 1, 1, ...
  5. Training is done on very short sequences of length 3.
  6. Loss is computed on all but the last predicted value for simplicity.
  7.  
  8. When REMEMBER_STATE is True, the previous LSTM state is transferred
  9. to the next training step and the network reaches almost zero error.
  10. When REMEMBER_STATE is False, the network has no way to predict the
  11. second element from the first, so it outputs 0.5. It is still able to
  12. predict the third element.
  13. """
  14.  
  15. import random
  16. import numpy as np
  17.  
  18. import tensorflow as tf
  19. from tensorflow.contrib import rnn
  20.  
  21. REMEMBER_STATE = True
  22.  
  23. NUM_HIDDEN = 10
  24. BATCH_SIZE = 100
  25. LENGTH = 3
  26.  
  27. saved_c = tf.get_variable("saved_c", shape=[BATCH_SIZE, NUM_HIDDEN], dtype=tf.float32)
  28. saved_h = tf.get_variable("saved_h", shape=[BATCH_SIZE, NUM_HIDDEN], dtype=tf.float32)
  29. mlp = tf.layers.Dense(1)
  30.  
  31. x = tf.placeholder(tf.float32, [BATCH_SIZE, LENGTH, 1])
  32. xs = tf.unstack(x, LENGTH, axis=1)
  33. initial_c=tf.placeholder(tf.float32, [BATCH_SIZE, NUM_HIDDEN])
  34. initial_h=tf.placeholder(tf.float32, [BATCH_SIZE, NUM_HIDDEN])
  35.  
  36. initial_state = rnn.LSTMStateTuple(c=initial_c, h=initial_h)
  37. cell = rnn.BasicLSTMCell(NUM_HIDDEN)
  38.  
  39. # outputs - a list of length LENGTH, each element a tensor of shape [BATCH_SIZE, NUM_HIDDEN]
  40. # states - LSTMStateTuple with both c and h having shape [BATCH_SIZE, NUM_HIDDEN]
  41. outputs, states = rnn.static_rnn(cell, inputs=xs, initial_state=initial_state)
  42.  
  43. assign_c = tf.assign(saved_c, states.c)
  44. assign_h = tf.assign(saved_h, states.h)
  45. with tf.control_dependencies([assign_c, assign_h]):
  46. assign_op = tf.no_op()
  47.  
  48. def loss(inputs, outputs):
  49. loss = 0
  50. # Predictions for first example in the batch
  51. predictions = []
  52. for output, labels in zip(outputs, tf.unstack(inputs[:, 1:, :], axis=1)):
  53. # prediction shape = [BATCH_SIZE, 1]
  54. prediction = mlp(output)
  55. predictions.append(prediction[0, 0])
  56. loss += tf.sqrt(tf.losses.mean_squared_error(labels=labels,
  57. predictions=prediction))
  58. return loss, tf.stack(predictions)
  59.  
# Total RMSE loss over the predicted steps, and the per-step predictions
# for the first example in the batch; minimized with default-rate Adam.
floss, predictions = loss(x, outputs)
train_op = tf.train.AdamOptimizer().minimize(floss)
  62.  
  63. def input_gen():
  64. repeats = 2
  65. nums = 2
  66.  
  67. cache = {}
  68. template = np.repeat(range(nums) * 2 * LENGTH, repeats=repeats)
  69. def numpy_cache(pos):
  70. key = pos % (nums * repeats)
  71. if key not in cache:
  72. cache[key] = template[key:(key + LENGTH)]
  73. return cache[key]
  74.  
  75. positions = [random.randint(0, nums * repeats) for _ in xrange(BATCH_SIZE)]
  76. def get_input():
  77. result = np.zeros([BATCH_SIZE, LENGTH, 1])
  78. for i in xrange(BATCH_SIZE):
  79. result[i, :, 0] = numpy_cache(positions[i])
  80. positions[i] += LENGTH
  81. return result
  82.  
  83. return get_input
  84.  
  85.  
# Train the LSTM, optionally feeding each step's final state back in as
# the next step's initial state (REMEMBER_STATE).
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Zero state for the very first training step.
    c_val = np.zeros([BATCH_SIZE, NUM_HIDDEN])
    h_val = np.zeros([BATCH_SIZE, NUM_HIDDEN])

    generator = input_gen()
    losses = []
    for i in xrange(10000):
        x_val = generator()
        if REMEMBER_STATE:
            # Fetch saved_c/saved_h in the same run that assigns them via
            # assign_op. NOTE(review): fetching a variable alongside its
            # own assign can nondeterministically read the pre- or
            # post-assign value in TF 1.x graphs — confirm this is intended.
            c_val, h_val, loss_val, _, _, first_pred = sess.run(
                [saved_c, saved_h, floss, train_op, assign_op, predictions],
                feed_dict={x: x_val,
                           initial_c: c_val,
                           initial_h: h_val})
        else:
            # c_val / h_val are never updated, so every batch starts from
            # the all-zero state.
            loss_val, _, first_pred = sess.run(
                [floss, train_op, predictions],
                feed_dict={x: x_val,
                           initial_c: c_val,
                           initial_h: h_val})
        losses.append(loss_val)

        # Report a 100-step moving average of the loss every 101 iterations.
        if i % 101 == 0 and len(losses) >= 100:
            print "iteration:", i
            print "loss:", sum(losses[-100:]) / 100.0
            print "predictions on first example:", first_pred
            print "input:", x_val[0, :, 0]
            print
Add Comment
Please, Sign In to add comment