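# Variational autoencoder (VAE) trained on MNIST.
# This appears to be adapted from the stock Keras example
# (keras/examples/variational_autoencoder.py); Keras 2.x with the
# TensorFlow backend is assumed throughout.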
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

from keras.layers import Input, Dense, Lambda, Layer
from keras.models import Model
from keras import backend as K
from keras import metrics
from keras.datasets import mnist

observations = 10  # unused once original_dim is corrected below
batch_size = 1  # note: batch size 1 trains very slowly; the stock example uses 100
# original_dim must equal the flattened input dimensionality. The paste set it to
# 9 + (4 * observations) = 49, which does not match the 784-dimensional MNIST
# vectors the model is trained on below, so it is corrected to 28 * 28 = 784 here.
original_dim = 784
latent_dim = 2
intermediate_dim = 20
epochs = 50
epsilon_std = 1.0


x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)


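# Reparameterization trick: instead of sampling z ~ N(z_mean, exp(z_log_var))
# directly, draw eps ~ N(0, I) and compute z = z_mean + exp(z_log_var / 2) * eps,
# which keeps the sampling step differentiable w.r.t. z_mean and z_log_var.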
def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0.,
                              stddev=epsilon_std)
    return z_mean + K.exp(z_log_var / 2) * epsilon

# note that "output_shape" isn't necessary with the TensorFlow backend
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# we instantiate these layers separately so as to reuse them later
decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)


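# The VAE loss combines a reconstruction term (binary cross-entropy summed over
# the original_dim outputs) with the KL divergence between the approximate
# posterior q(z|x) = N(z_mean, exp(z_log_var)) and the N(0, I) prior, which has
# the closed form -0.5 * sum(1 + z_log_var - z_mean^2 - exp(z_log_var)).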
# Custom loss layer
class CustomVariationalLayer(Layer):
    def __init__(self, **kwargs):
        self.is_placeholder = True
        super(CustomVariationalLayer, self).__init__(**kwargs)

    def vae_loss(self, x, x_decoded_mean):
        xent_loss = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
        kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return K.mean(xent_loss + kl_loss)

    def call(self, inputs):
        x = inputs[0]
        x_decoded_mean = inputs[1]
        loss = self.vae_loss(x, x_decoded_mean)
        self.add_loss(loss, inputs=inputs)
        # We won't actually use the output.
        return x

y = CustomVariationalLayer()([x, x_decoded_mean])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)
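# loss=None works because CustomVariationalLayer registered the VAE loss via
# add_loss(), so fit() needs no target array (hence no labels below and
# validation_data=(x_test, None)).
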
# train the VAE on MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
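# x_train / x_test are now float32 arrays of shape (n, 784) scaled to [0, 1],
# matching original_dim and the sigmoid + binary cross-entropy reconstruction.
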
vae.fit(x_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, None))

# build a model to project inputs on the latent space
encoder = Model(x, z_mean)
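# encoder shares its layers with vae above, so it maps inputs through the trained weights
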
# display a 2D plot of the digit classes in the latent space
x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
plt.colorbar()
plt.show()

# build a digit generator that can sample from the learned distribution
decoder_input = Input(shape=(latent_dim,))
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)

# display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# linearly spaced coordinates on the unit square were transformed through the
# inverse CDF (ppf) of the Gaussian to produce values of the latent variables z,
# since the prior of the latent space is Gaussian
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

# decode each latent grid point and tile the resulting digits into one image
for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
        x_decoded = generator.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()
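
# Example usage (a minimal sketch): sample one digit straight from the N(0, I)
# prior instead of walking the grid above.
# z_random = np.random.normal(size=(1, latent_dim))
# digit = generator.predict(z_random)[0].reshape(digit_size, digit_size)
# plt.imshow(digit, cmap='Greys_r'); plt.show()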