Advertisement
milanmetal

[SciKit] Color Quantization using K-Means

Mar 6th, 2018
315
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.75 KB | None | 0 0
  1. # Authors: Robert Layton <robertlayton@gmail.com>
  2. #          Olivier Grisel <olivier.grisel@ensta.org>
  3. #          Mathieu Blondel <mathieu@mblondel.org>
  4. #   http://scikit-learn.org/stable/auto_examples/cluster/plot_color_quantization.html
  5. #
  6. # License: BSD 3 clause
  7. from PIL import Image
  8. print(__doc__)
  9. import numpy as np
  10. import matplotlib.pyplot as plt
  11. from sklearn.cluster import KMeans
  12. from sklearn.metrics import pairwise_distances_argmin
  13. from sklearn.datasets import load_sample_image
  14. from sklearn.utils import shuffle
  15. from time import time
  16.  
  17. n_colors = 4
  18.  
  19. # Load the Summer Palace photo
  20. # image = load_sample_image("image.jpg")
  21.  
  22. image = Image.open("lena.jpg")
  23.  
  24. # Convert to floats instead of the default 8 bits integer coding. Dividing by
  25. # 255 is important so that plt.imshow behaves works well on float data (need to
  26. # be in the range [0-1])
  27. image = np.array(image, dtype=np.float64) / 255
  28.  
  29. # Load Image and transform to a 2D numpy array.
  30. w, h, d = original_shape = tuple(image.shape)
  31. assert d == 3
  32. image_array = np.reshape(image, (w * h, d))
  33.  
  34. print("Fitting model on a small sub-sample of the data")
  35. t0 = time()
  36. image_array_sample = shuffle(image_array, random_state=0)[:1000]
  37. kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(image_array_sample)
  38. print("done in %0.3fs." % (time() - t0))
  39.  
  40. # Get labels for all points
  41. print("Predicting color indices on the full image (k-means)")
  42. t0 = time()
  43. labels = kmeans.predict(image_array)
  44. print("done in %0.3fs." % (time() - t0))
  45.  
  46.  
  47. codebook_random = shuffle(image_array, random_state=0)[:n_colors + 1]
  48. print("Predicting color indices on the full image (random)")
  49. t0 = time()
  50. labels_random = pairwise_distances_argmin(codebook_random,
  51.                                           image_array,
  52.                                           axis=0)
  53. print("done in %0.3fs." % (time() - t0))
  54.  
  55.  
  56. def recreate_image(codebook, labels, w, h):
  57.     """Recreate the (compressed) image from the code book & labels"""
  58.     d = codebook.shape[1]
  59.     image = np.zeros((w, h, d))
  60.     label_idx = 0
  61.     for i in range(w):
  62.         for j in range(h):
  63.             image[i][j] = codebook[labels[label_idx]]
  64.             label_idx += 1
  65.     return image
  66.  
  67.  
  68. # Display all results, alongside original image
  69. plt.figure(1)
  70. plt.clf()
  71. ax = plt.axes([0, 0, 1, 1])
  72. plt.axis('off')
  73. plt.title('Original image (96,615 colors)')
  74. plt.imshow(image)
  75.  
  76. plt.figure(2)
  77. plt.clf()
  78. ax = plt.axes([0, 0, 1, 1])
  79. plt.axis('off')
  80. plt.title('Quantized image (64 colors, K-Means)')
  81. plt.imshow(recreate_image(kmeans.cluster_centers_, labels, w, h))
  82.  
  83. plt.figure(3)
  84. plt.clf()
  85. ax = plt.axes([0, 0, 1, 1])
  86. plt.axis('off')
  87. plt.title('Quantized image (64 colors, Random)')
  88. plt.imshow(recreate_image(codebook_random, labels_random, w, h))
  89. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement