Advertisement
Guest User

Untitled

a guest
Apr 24th, 2019
66
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 0.99 KB | None | 0 0
import numpy as np
import scipy
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
  7.  
  8. def find_permutation(n_clusters, real_labels, labels):
  9.     permutation=[]
  10.     for i in range(n_clusters):
  11.         idx = labels == i
  12.         # Choose the most common label among data points in the cluster
  13.         new_label=scipy.stats.mode(real_labels[idx])[0][0]  
  14.         permutation.append(new_label)
  15.     return permutation
  16.    
  17. def plant_clustering():
  18.  
  19.     X , y  = load_iris(return_X_y=True)
  20.  
  21.     cluss = KMeans( 3 , random_state  = 0 )
  22.     cluss.fit( X )
  23.  
  24.     y_pred = cluss.predict( X )
  25.     #print( y_pred )
  26.     #print( y )
  27.  
  28.     perm = find_permutation( 3 , y , y_pred )
  29.     #print(perm )
  30.     acc = accuracy_score( y , [ perm[label] for label in cluss.labels_])
  31.  
  32.     #print(acc)
  33.     return acc
  34.    
  35. def main():
  36.     print(plant_clustering())
  37.  
  38. if __name__ == "__main__":
  39.     main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement