Advertisement
Guest User

plot_digits_agglomeration.py

a guest
Nov 20th, 2017
86
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.65 KB | None | 0 0
  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3.  
  4. """
  5. =========================================================
  6. Feature agglomeration
  7. =========================================================
  8.  
  9. These images how similar features are merged together using
  10. feature agglomeration.
  11. """
  12. print(__doc__)
  13.  
  14. # Code source: Gaël Varoquaux
  15. # Modified for documentation by Jaques Grobler
  16. # License: BSD 3 clause
  17.  
  18. import numpy as np
  19. import matplotlib.pyplot as plt
  20.  
  21. from sklearn import datasets, cluster
  22. from sklearn.feature_extraction.image import grid_to_graph
  23.  
  24. digits = datasets.load_digits()
  25. images = digits.images
  26. X = np.reshape(images, (len(images), -1))
  27. connectivity = grid_to_graph(*images[0].shape)
  28.  
  29. agglo = cluster.FeatureAgglomeration(connectivity=connectivity,
  30.                                      n_clusters=32)
  31.  
  32. agglo.fit(X)
  33. X_reduced = agglo.transform(X)
  34.  
  35. X_restored = agglo.inverse_transform(X_reduced)
  36. images_restored = np.reshape(X_restored, images.shape)
  37. plt.figure(1, figsize=(4, 3.5))
  38. plt.clf()
  39. plt.subplots_adjust(left=.01, right=.99, bottom=.01, top=.91)
  40. for i in range(4):
  41.     plt.subplot(3, 4, i + 1)
  42.     plt.imshow(images[i], cmap=plt.cm.gray, vmax=16, interpolation='nearest')
  43.     plt.xticks(())
  44.     plt.yticks(())
  45.     if i == 1:
  46.         plt.title('Original data')
  47.     plt.subplot(3, 4, 4 + i + 1)
  48.     plt.imshow(images_restored[i], cmap=plt.cm.gray, vmax=16,
  49.                interpolation='nearest')
  50.     if i == 1:
  51.         plt.title('Agglomerated data')
  52.     plt.xticks(())
  53.     plt.yticks(())
  54.  
  55. plt.subplot(3, 4, 10)
  56. plt.imshow(np.reshape(agglo.labels_, images[0].shape),
  57.            interpolation='nearest', cmap=plt.cm.spectral)
  58. plt.xticks(())
  59. plt.yticks(())
  60. plt.title('Labels')
  61. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement