Advertisement
Guest User

Untitled

a guest
Jan 17th, 2017
71
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.21 KB | None | 0 0
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from __future__ import division
  4. import tensorflow as tf
  5. import numpy as np
  6. from tensorflow.examples.tutorials.mnist import input_data
  7. from PIL import Image
  8. from utils import tile_raster_images, scale_to_unit_interval
  9. import math
  10. import matplotlib.pyplot as plt
  11.  
  12.  
  13. def main():
  14. mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)
  15. trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
  16.  
  17. Nv = 784
  18. v_shape = (28, 28)
  19. Nh = 100
  20. h1_shape = (10, 10)
  21.  
  22. gibbs_sampling_steps = 1
  23. alpha = 0.1 # koeficijent učenja
  24.  
  25. g1 = tf.Graph()
  26. with g1.as_default():
  27. X1 = tf.placeholder("float", [None, 784]) # X1 [batch_size x 784]
  28. w1 = weights([Nv, Nh]) # w1 [784 x 100]
  29. vb1 = bias([Nv])
  30. hb1 = bias([Nh])
  31.  
  32. h0_prob = tf.random_normal([tf.shape(X1)[0], Nh])
  33. h0 = sample_prob(h0_prob)
  34. h1 = h0 # h1 [batch_size x 100]
  35.  
  36. for step in range(gibbs_sampling_steps):
  37. v1_prob = tf.sigmoid(tf.matmul(h1, tf.transpose(w1)) + vb1) # v1 [batch_size x 784]
  38. v1 = sample_prob(v1_prob)
  39.  
  40. h1_prob = tf.sigmoid(tf.matmul(v1, w1) + hb1) # h1 [batch_size x 100]
  41. h1 = sample_prob(h1_prob)
  42.  
  43. # pozitivna faza
  44. w1_positive_grad = tf.matmul(tf.transpose(X1), h0) # [784, 100]
  45. # negativna faza
  46. w1_negative_grad = tf.matmul(tf.transpose(v1_prob), h1) # [784, 100]
  47.  
  48. dw1 = (w1_positive_grad - w1_negative_grad) / tf.to_float(tf.shape(X1)[0])
  49.  
  50. # operacije za osvježavanje parametara mreže - one pokreću učenje RBM-a
  51. update_w1 = tf.assign_add(w1, alpha * dw1)
  52. update_vb1 = tf.assign_add(vb1, alpha * tf.reduce_mean(X1 - v1_prob, 0))
  53. update_hb1 = tf.assign_add(hb1, alpha * tf.reduce_mean(h0 - h1, 0))
  54.  
  55. out1 = (update_w1, update_vb1, update_hb1)
  56.  
  57. # rekonstrukcija ulaznog vektora - koristimo vjerojatnost p(v=1)
  58. v1_prob = tf.sigmoid(tf.matmul(h1, tf.transpose(w1)) + vb1)
  59.  
  60. err1 = X1 - v1_prob
  61. err_sum1 = tf.reduce_mean(err1 * err1)
  62.  
  63. initialize1 = tf.initialize_all_variables()
  64.  
  65. batch_size = 100
  66. epochs = 100
  67. n_samples = mnist.train.num_examples
  68.  
  69. total_batch = int(n_samples / batch_size) * epochs
  70.  
  71. with tf.Session(graph=g1) as sess:
  72. sess.run(initialize1)
  73. for i in range(total_batch):
  74. batch, label = mnist.train.next_batch(batch_size)
  75. err, _ = sess.run([err_sum1, out1], feed_dict={X1: batch})
  76.  
  77. if i % (int(total_batch / 10)) == 0:
  78. print i, err
  79.  
  80. w1s = w1.eval()
  81. vb1s = vb1.eval()
  82. hb1s = hb1.eval()
  83. vr, h1s = sess.run([v1_prob, h1], feed_dict={X1: teX[0:2, :]})
  84.  
  85. # vizualizacija težina
  86. draw_weights(w1s, v_shape, Nh)
  87.  
  88. # vizualizacija rekonstrukcije i stanja
  89. draw_reconstructions(teX, vr, h1s, v_shape, h1_shape, Nh)
  90.  
  91.  
  92. def weights(shape):
  93. initial = tf.truncated_normal(shape, stddev=0.1)
  94. return tf.Variable(initial)
  95.  
  96.  
  97. def bias(shape):
  98. initial = tf.zeros(shape, dtype=tf.float32)
  99. return tf.Variable(initial)
  100.  
  101.  
  102. def sample_prob(probs):
  103. """Uzorkovanje vektora x prema vektoru vjerojatnosti p(x=1) = probs"""
  104. return tf.nn.relu(
  105. tf.sign(probs - tf.random_uniform(tf.shape(probs))))
  106.  
  107.  
  108. def draw_weights(W, shape, N, interpolation="bilinear"):
  109. """Vizualizacija težina
  110.  
  111. W -- vektori težina
  112. shape -- tuple dimenzije za 2D prikaz težina - obično dimenzije ulazne slike, npr. (28,28)
  113. N -- broj vektora težina
  114. """
  115. image = Image.fromarray(tile_raster_images(
  116. X=W.T,
  117. img_shape=shape,
  118. tile_shape=(int(math.ceil(N / 20)), 20),
  119. tile_spacing=(1, 1)))
  120. plt.figure(figsize=(10, 14))
  121. plt.imshow(image, interpolation=interpolation)
  122.  
  123.  
  124. def draw_reconstructions(ins, outs, states, shape_in, shape_state, Nh):
  125. """Vizualizacija ulaza i pripadajućih rekonstrkcija i stanja skrivenog sloja
  126. ins -- ualzni vektori
  127. outs -- rekonstruirani vektori
  128. states -- vektori stanja skrivenog sloja
  129. shape_in -- dimezije ulaznih slika npr. (28,28)
  130. shape_state -- dimezije za 2D prikaz stanja (npr. za 100 stanja (10,10)
  131. """
  132. plt.figure(figsize=(8, 12 * 4))
  133. for i in range(20):
  134.  
  135. plt.subplot(20, 4, 4 * i + 1)
  136. plt.imshow(ins[i].reshape(shape_in), vmin=0,
  137. vmax=1, interpolation="nearest")
  138. plt.title("Test input")
  139. plt.subplot(20, 4, 4 * i + 2)
  140. plt.imshow(outs[i][0:784].reshape(shape_in),
  141. vmin=0, vmax=1, interpolation="nearest")
  142. plt.title("Reconstruction")
  143. plt.subplot(20, 4, 4 * i + 3)
  144. plt.imshow(states[i][0:Nh].reshape(shape_state),
  145. vmin=0, vmax=1, interpolation="nearest")
  146. plt.title("States")
  147. plt.tight_layout()
  148.  
  149.  
  150. if __name__ == '__main__':
  151. main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement