Advertisement
Guest User

Untitled

a guest
Apr 23rd, 2019
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.64 KB | None | 0 0
  1. import os
  2. import time
  3. import numpy as np
  4.  
  5. from abc import abstractmethod
  6.  
  7.  
  8. class ClusterTemplate(object):
  9.  
  10. def __init__(self, path):
  11. self.path = path
  12. if not os.path.exists(self.path):
  13. os.makedirs(self.path, exist_ok=True)
  14. self.centroids = None
  15. self.clustering_algo = None
  16.  
  17. @abstractmethod
  18. def init_cluster_algo(self, num_clusters):
  19. raise NotImplementedError("Abstract method 'init_cluster_algo' not implemented")
  20.  
  21. def train(self, embeddings):
  22. self.clustering_algo.fit(embeddings)
  23.  
  24. def should_normalize(self):
  25. return True
  26.  
  27. def normalize(self, embeddings, num_clusters):
  28. # Extract the assigned cluster labels
  29. labels = self.clustering_algo.labels_
  30.  
  31. # Generate centroids using the features and assigned cluster labels
  32. data = np.empty((0, features.shape[1]), 'float32')
  33. for i in range(num_clusters):
  34. row = np.dot(labels == i, embeddings) / np.sum(labels == i)
  35. data = np.vstack((data, row))
  36.  
  37. # Normalize
  38. tdata = data.transpose()
  39. self.centroids = (tdata / np.sqrt(np.sum(tdata * tdata, axis=0))).transpose()
  40.  
  41. def save(self, cluster_name):
  42. np.save(os.path.join(self.path, cluster_name), self.centroids)
  43.  
  44. # Final method that no sub class must override. Should be invoked directly from the client
  45. def cluster(self, features, cluster_name, niter=20, num_clusters=100):
  46.  
  47. self.init_cluster_algo(num_clusters)
  48.  
  49. self.train(embeddings)
  50.  
  51. if self.should_normalize():
  52. self.normalize(embeddings, num_clusters)
  53.  
  54. self.save(cluster_name)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement