SHARE
TWEET

Untitled

a guest Apr 22nd, 2019 61 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1.  
  2. # coding: utf-8
  3.  
  4. # # Laborator 4
  5.  
  6. # In[1]:
  7.  
  8.  
  9. import numpy as np
  10. import matplotlib.pyplot as plt
  11.  
  12.  
  13. # In[2]:
  14.  
  15.  
  16. dataPath = "data/"
  17. #load train images
  18. train_images = np.loadtxt(dataPath + "train_images.txt")
  19. train_labels = np.loadtxt(dataPath + "train_labels.txt",'int8')
  20. print(train_images.shape)
  21. print(train_images.ndim)
  22. print(type(train_images[0,0]))
  23. print(train_images.size)
  24. print(train_images.nbytes)
  25.  
  26.  
  27.  
  28.  
  29.  
  30. # In[3]:
  31.  
  32.  
  33. #plot the first 100 training images with their labels in a 10 x 10 subplot
  34. nbImages = 10
  35. plt.figure(figsize=(5,5))
  36. for i in range(nbImages**2):
  37.     plt.subplot(nbImages,nbImages,i+1)
  38.     plt.axis('off')
  39.     plt.imshow(np.reshape(train_images[i,:],(28,28)),cmap = "gray")
  40. plt.show()
  41. labels_nbImages = train_labels[:nbImages**2]
  42. print(np.reshape(labels_nbImages,(nbImages,nbImages)))
  43.  
  44.  
  45. # In[4]:
  46.  
  47.  
  48. #load test images
  49. test_images = np.loadtxt(dataPath + "test_images.txt")
  50. test_labels = np.loadtxt(dataPath + "test_labels.txt",'int8')
  51. #plot the first 100 testing images with their labels in a 10 x 10 subplot
  52. nbImages = 10
  53. plt.figure(figsize=(5,5))
  54. for i in range(nbImages**2):
  55.     plt.subplot(nbImages,nbImages,i+1)
  56.     plt.axis('off')
  57.     plt.imshow(np.reshape(test_images[i,:],(28,28)),cmap = "gray")
  58. plt.show()
  59. labels_nbImages = test_labels[:nbImages**2]
  60. print(np.reshape(labels_nbImages,(nbImages,nbImages)))
  61.  
  62.  
  63. # In[ ]:
  64.  
  65. img = test_images[0]
  66. distances = np.sqrt( (train_images - img)**2, ).sum(axis=1)
  67. print(distances.shape)
  68. indices = distances.argsort
  69. print(indices[0])
  70.  
  71.  
  72. print(distances[804])
  73. print(distances.min())
  74. print(min(distances))
  75.  
  76. print(train_labels[804])
  77.  
  78.  
  79. plt.show("Imagine:")
  80. plt.show()
  81.  
  82. #do 1-NN, 3-NN, 5-NN, 7 -NN for the first test image
  83. #plot the neighbors
  84.  
  85.  
  86. a = np.array[0,5,7,7,5,1,5]
  87. b = np.bincount(a)
  88. print(b)
  89.  
  90.  
  91. # In[ ]:
  92.  
  93.  
  94. #define class Knn_classifier
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top