SHARE
TWEET

Untitled

a guest Aug 23rd, 2019 67 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import matplotlib.pyplot as plt
  2. from keras.datasets import mnist
  3. import numpy as np
  4. from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
  5. from keras.models import Model
  6. from keras import backend as K
  7.  
  8. (x_train, _), (x_test, _) = mnist.load_data()
  9.  
  10. x_train = x_train.astype('float32') / 255.
  11. x_test = x_test.astype('float32') / 255.
  12. x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))  # adapt this if using `channels_first` image data format
  13. x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))  # adapt this if using `channels_first` image data format
  14.  
  15. noise_factor = 0.5
  16. x_train_noisy = x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape)
  17. x_test_noisy = x_test + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape)
  18.  
  19. x_train_noisy = np.clip(x_train_noisy, 0., 1.)
  20. x_test_noisy = np.clip(x_test_noisy, 0., 1.)
  21.  
  22. n = 10
  23. plt.figure(figsize=(20, 2))
  24. for i in range(1, n+1):
  25.     ax = plt.subplot(1, n, i)
  26.     plt.imshow(x_test_noisy[i].reshape(28, 28))
  27.     plt.gray()
  28.     ax.get_xaxis().set_visible(False)
  29.     ax.get_yaxis().set_visible(False)
  30. plt.show()
  31.  
  32.  
  33. input_img = Input(shape=(28, 28, 1))  # adapt this if using `channels_first` image data format
  34.  
  35. # use Conv2D, MaxPooling2D - twice
  36. # use Conv2D, UpSampling2D - twice
  37. x = Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
  38. x = MaxPooling2D((2, 2), padding='same')(x)
  39. x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
  40. encoded = MaxPooling2D((2, 2), padding='same')(x)
  41.  
  42. # at this point the representation is (7, 7, 32)
  43.  
  44. x = Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
  45. x = UpSampling2D((2, 2))(x)
  46. x = Conv2D(16, (3, 3), activation='relu', padding='same')(x)
  47. x = UpSampling2D((2, 2))(x)
  48. decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)
  49.  
  50. autoencoder = Model(input_img, decoded)
  51. autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
  52. autoencoder.summary()
  53.  
  54. autoencoder.fit(x_train_noisy, x_train,
  55.                 epochs=10,
  56.                 batch_size=128,
  57.                 shuffle=True,
  58.                 validation_data=(x_test_noisy, x_test))
  59.  
  60.  
  61. decoded_imgs = autoencoder.predict(x_test)
  62.  
  63. import matplotlib.pyplot as plt
  64.  
  65. n = 10
  66. plt.figure(figsize=(20, 4))
  67. for i in range(1, n+1):
  68.     # display original
  69.     ax = plt.subplot(2, n, i)
  70.     plt.imshow(x_test[i].reshape(28, 28))
  71.     plt.gray()
  72.     ax.get_xaxis().set_visible(False)
  73.     ax.get_yaxis().set_visible(False)
  74.  
  75.     # display reconstruction
  76.     ax = plt.subplot(2, n, i + n)
  77.     plt.imshow(decoded_imgs[i].reshape(28, 28))
  78.     plt.gray()
  79.     ax.get_xaxis().set_visible(False)
  80.     ax.get_yaxis().set_visible(False)
  81. plt.show()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top