Guest User

Untitled

a guest
Jan 21st, 2018
57
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.82 KB | None | 0 0
  1. import pandas as pd
  2. import matplotlib.pyplot as plt
  3. from sklearn import cluster
  4. from scipy.spatial import distance
  5. import sklearn.datasets
  6. from sklearn.preprocessing import StandardScaler
  7. import numpy as np
  8.  
  9. def compute_bic(kmeans,X):
  10. """
  11. Computes the BIC metric for a given clusters
  12.  
  13. Parameters:
  14. -----------------------------------------
  15. kmeans: List of clustering object from scikit learn
  16.  
  17. X : multidimension np array of data points
  18.  
  19. Returns:
  20. -----------------------------------------
  21. BIC value
  22. """
  23. # assign centers and labels
  24. centers = [kmeans.cluster_centers_]
  25. labels = kmeans.labels_
  26. #number of clusters
  27. m = kmeans.n_clusters
  28. # size of the clusters
  29. n = np.bincount(labels)
  30. #size of data set
  31. N, d = X.shape
  32.  
  33. #compute variance for all clusters beforehand
  34. cl_var = (1.0 / (N - m) / d) * sum([sum(distance.cdist(X[np.where(labels == i)], [centers[0][i]],
  35. 'euclidean')**2) for i in range(m)])
  36.  
  37. const_term = 0.5 * m * np.log(N) * (d+1)
  38.  
  39. BIC = np.sum([n[i] * np.log(n[i]) -
  40. n[i] * np.log(N) -
  41. ((n[i] * d) / 2) * np.log(2*np.pi*cl_var) -
  42. ((n[i] - 1) * d/ 2) for i in range(m)]) - const_term
  43.  
  44. return(BIC)
  45.  
  46. path = 'C:/Users/Lionel/Downloads'
  47. file = 'Wholesale customers data.csv'
  48. data = pd.read_csv(path + '/'+file)
  49. X = np.array(data.iloc[:,2 :])
  50. #Xs = StandardScaler().fit_transform(X)
  51. ks = range(1,100)
  52.  
  53. # run 100 times kmeans and save each result in the KMeans object
  54. KMeans = [cluster.KMeans(n_clusters = i, init="k-means++").fit(X) for i in ks]
  55.  
  56. # now run for each cluster the BIC computation
  57. BIC = [compute_bic(kmeansi,X) for kmeansi in KMeans]
  58.  
  59. kopt = BIC.index(max(BIC)) + 1
  60.  
  61. print (BIC)
  62. print( 'kopt = ' + str(kopt) )
  63.  
  64. plt.plot(ks,BIC,'r-o')
  65. plt.title("BIC vs number of clusters")
  66. plt.xlabel("# clusters")
  67. plt.ylabel("# BIC")
  68. plt.show()
Add Comment
Please, Sign In to add comment