Advertisement
Guest User

Untitled

a guest
Nov 24th, 2017
74
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.41 KB | None | 0 0
import matplotlib.pyplot as plt
import tensorflow as tf

from vae_carl import mnist_data
from vae_carl.helpers import batch_index_groups, dtype
import vae_carl.vae as vae
  6.  
# ---------------------------------------------------------------------------
# Data loading, hyper-parameters, and TF1-style graph construction.
# ---------------------------------------------------------------------------

# NOTE(review): only the first and last two return values are used here;
# the middle three are discarded. Exact tuple layout — confirm against
# vae_carl.mnist_data.prepare_MNIST_data.
train_total_data, _, _, _, test_data, test_labels = mnist_data.prepare_MNIST_data()

train_size = 10000                 # number of training samples actually used
IMAGE_SIZE_MNIST = 28              # MNIST images are 28x28 pixels
num_hidden = 500                   # hidden-layer width for the autoencoder
dim_img = IMAGE_SIZE_MNIST ** 2    # flattened image dimensionality (784)
dim_z = 2                          # latent-space dimensionality
learn_rate = 1e-3                  # Adam learning rate
batch_size = min(128, train_size)  # never request a batch larger than the data
num_epochs = 10

# Graph inputs: flattened images in, flattened target images out. For plain
# autoencoding both placeholders are fed the same batch (see training loop).
y_input = tf.placeholder(dtype, shape=[None, dim_img], name='input_img')
y_output_true = tf.placeholder(dtype, shape=[None, dim_img], name='target_img')

# dropout
# Keep probability is fed per session.run (0.9 during training below).
keep_prob = tf.placeholder(dtype, name='keep_prob')

# network architecture
ae = vae.autoencoder(
    y_input=y_input,
    y_output_true=y_output_true,
    dim_img=dim_img,
    dim_z=dim_z,
    num_hidden=num_hidden,
    keep_prob=keep_prob
)

# optimization
# Single Adam step minimizing the combined VAE loss exposed by the model.
train_step = tf.train.AdamOptimizer(learn_rate).minimize(ae.loss)

# train_total_data rows are [flattened image | one-hot label] — split into
# pixels and labels. Labels are kept but unused in this script.
y_train = train_total_data[:train_size, :-mnist_data.NUM_LABELS]
y_train_labels = train_total_data[:train_size, -mnist_data.NUM_LABELS:]

print("Num data points", train_size)
print("Num epochs", num_epochs)
  42.  
  43. with tf.Session() as session:
  44.  
  45.     session.run(tf.global_variables_initializer())
  46.     session.graph.finalize()
  47.  
  48.     for epoch in range(num_epochs):
  49.         for i, batch_indices in enumerate(batch_index_groups(batch_size=batch_size, num_samples=train_size)):
  50.  
  51.             batch_xs_input = y_train[batch_indices, :]
  52.  
  53.             _, tot_loss, loss_likelihood, loss_divergence, learnt_rep = session.run(
  54.                 (
  55.                     train_step,
  56.                     ae.loss,
  57.                     ae.neg_marginal_likelihood,
  58.                     ae.kl_divergence,
  59.                     ae.y_output
  60.                 ),
  61.                 feed_dict={
  62.                     y_input: batch_xs_input,
  63.                     y_output_true: batch_xs_input,
  64.                     keep_prob: 0.9
  65.                 }
  66.             )
  67.             print("SHAPE: {}".format(learnt_rep.shape))
  68.             print("TYPE: {}".format(type(learnt_rep)))
  69.             fig = plt.figure()
  70.             ax = fig.add_subplot(131)
  71.             ax.imshow(learnt_rep, cmap='gray')
  72.             plt.show()
  73.  
  74.         print(
  75.             "epoch %d: L_tot %03.2f L_likelihood %03.2f L_divergence %03.2f" % (
  76.                 epoch,
  77.                 tot_loss,
  78.                 loss_likelihood,
  79.                 loss_divergence
  80.             )
  81.         )
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement