Advertisement
Guest User

Untitled

a guest
Aug 23rd, 2019
121
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.61 KB | None | 0 0
  1. import matplotlib.pyplot as plt
  2. from keras.datasets import mnist
  3. import numpy as np
  4. from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
  5. from keras.models import Model
  6. from keras import backend as K
  7.  
  8. (x_train, _), (x_test, _) = mnist.load_data()
  9.  
  10. x_train = x_train.astype('float32') / 255.
  11. x_test = x_test.astype('float32') / 255.
  12. x_train = np.reshape(x_train, (len(x_train), 28, 28, 1)) # adapt this if using `channels_first` image data format
  13. x_test = np.reshape(x_test, (len(x_test), 28, 28, 1)) # adapt this if using `channels_first` image data format
  14.  
  15. noise_factor = 0.5
  16. x_train_noisy = x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape)
  17. x_test_noisy = x_test + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape)
  18.  
  19. x_train_noisy = np.clip(x_train_noisy, 0., 1.)
  20. x_test_noisy = np.clip(x_test_noisy, 0., 1.)
  21.  
  22. n = 10
  23. plt.figure(figsize=(20, 2))
  24. for i in range(1, n+1):
  25. ax = plt.subplot(1, n, i)
  26. plt.imshow(x_test_noisy[i].reshape(28, 28))
  27. plt.gray()
  28. ax.get_xaxis().set_visible(False)
  29. ax.get_yaxis().set_visible(False)
  30. plt.show()
  31.  
  32.  
  33. input_img = Input(shape=(28, 28, 1)) # adapt this if using `channels_first` image data format
  34.  
  35. # use Conv2D, MaxPooling2D - twice
  36. # use Conv2D, UpSampling2D - twice
  37. x = Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
  38. x = MaxPooling2D((2, 2), padding='same')(x)
  39. x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
  40. encoded = MaxPooling2D((2, 2), padding='same')(x)
  41.  
  42. # at this point the representation is (7, 7, 32)
  43.  
  44. x = Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
  45. x = UpSampling2D((2, 2))(x)
  46. x = Conv2D(16, (3, 3), activation='relu', padding='same')(x)
  47. x = UpSampling2D((2, 2))(x)
  48. decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)
  49.  
  50. autoencoder = Model(input_img, decoded)
  51. autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
  52. autoencoder.summary()
  53.  
  54. autoencoder.fit(x_train_noisy, x_train,
  55. epochs=10,
  56. batch_size=128,
  57. shuffle=True,
  58. validation_data=(x_test_noisy, x_test))
  59.  
  60.  
  61. decoded_imgs = autoencoder.predict(x_test)
  62.  
  63. import matplotlib.pyplot as plt
  64.  
  65. n = 10
  66. plt.figure(figsize=(20, 4))
  67. for i in range(1, n+1):
  68. # display original
  69. ax = plt.subplot(2, n, i)
  70. plt.imshow(x_test[i].reshape(28, 28))
  71. plt.gray()
  72. ax.get_xaxis().set_visible(False)
  73. ax.get_yaxis().set_visible(False)
  74.  
  75. # display reconstruction
  76. ax = plt.subplot(2, n, i + n)
  77. plt.imshow(decoded_imgs[i].reshape(28, 28))
  78. plt.gray()
  79. ax.get_xaxis().set_visible(False)
  80. ax.get_yaxis().set_visible(False)
  81. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement