Guest User

Untitled

a guest
Jan 16th, 2017
81
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.60 KB | None | 0 0
  1. import tensorflow as tf
  2. import numpy as np
  3. import input_data
  4. import Image
  5. from util import tile_raster_images
  6.  
  7.  
  8. def sample_prob(probs):
  9. return tf.nn.relu(
  10. tf.sign(
  11. probs - tf.random_uniform(tf.shape(probs))))
  12.  
  13. alpha = 1.0
  14. batchsize = 100
  15.  
  16. mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
  17. trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images,\
  18. mnist.test.labels
  19.  
  20. X = tf.placeholder("float", [None, 784])
  21. Y = tf.placeholder("float", [None, 10])
  22.  
  23. rbm_w = tf.placeholder("float", [784, 500])
  24. rbm_vb = tf.placeholder("float", [784])
  25. rbm_hb = tf.placeholder("float", [500])
  26. h0 = sample_prob(tf.nn.sigmoid(tf.matmul(X, rbm_w) + rbm_hb))
  27. v1 = sample_prob(tf.nn.sigmoid(
  28. tf.matmul(h0, tf.transpose(rbm_w)) + rbm_vb))
  29. h1 = tf.nn.sigmoid(tf.matmul(v1, rbm_w) + rbm_hb)
  30. w_positive_grad = tf.matmul(tf.transpose(X), h0)
  31. w_negative_grad = tf.matmul(tf.transpose(v1), h1)
  32. update_w = rbm_w + alpha * \
  33. (w_positive_grad - w_negative_grad) / tf.to_float(tf.shape(X)[0])
  34. update_vb = rbm_vb + alpha * tf.reduce_mean(X - v1, 0)
  35. update_hb = rbm_hb + alpha * tf.reduce_mean(h0 - h1, 0)
  36.  
  37. h_sample = sample_prob(tf.nn.sigmoid(tf.matmul(X, rbm_w) + rbm_hb))
  38. v_sample = sample_prob(tf.nn.sigmoid(
  39. tf.matmul(h_sample, tf.transpose(rbm_w)) + rbm_vb))
  40. err = X - v_sample
  41. err_sum = tf.reduce_mean(err * err)
  42.  
  43. sess = tf.Session()
  44. init = tf.initialize_all_variables()
  45. sess.run(init)
  46.  
  47. n_w = np.zeros([784, 500], np.float32)
  48. n_vb = np.zeros([784], np.float32)
  49. n_hb = np.zeros([500], np.float32)
  50. o_w = np.zeros([784, 500], np.float32)
  51. o_vb = np.zeros([784], np.float32)
  52. o_hb = np.zeros([500], np.float32)
  53. print sess.run(
  54. err_sum, feed_dict={X: trX, rbm_w: o_w, rbm_vb: o_vb, rbm_hb: o_hb})
  55.  
  56. for start, end in zip(
  57. range(0, len(trX), batchsize), range(batchsize, len(trX), batchsize)):
  58. batch = trX[start:end]
  59. n_w = sess.run(update_w, feed_dict={
  60. X: batch, rbm_w: o_w, rbm_vb: o_vb, rbm_hb: o_hb})
  61. n_vb = sess.run(update_vb, feed_dict={
  62. X: batch, rbm_w: o_w, rbm_vb: o_vb, rbm_hb: o_hb})
  63. n_hb = sess.run(update_hb, feed_dict={
  64. X: batch, rbm_w: o_w, rbm_vb: o_vb, rbm_hb: o_hb})
  65. o_w = n_w
  66. o_vb = n_vb
  67. o_hb = n_hb
  68. if start % 10000 == 0:
  69. print sess.run(
  70. err_sum, feed_dict={X: trX, rbm_w: n_w, rbm_vb: n_vb, rbm_hb: n_hb})
  71. image = Image.fromarray(
  72. tile_raster_images(
  73. X=n_w.T,
  74. img_shape=(28, 28),
  75. tile_shape=(25, 20),
  76. tile_spacing=(1, 1)
  77. )
  78. )
  79. image.save("rbm_%d.png" % (start / 10000))
Add Comment
Please, Sign In to add comment