Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import os
- import time
- import numpy as np
- from abc import abstractmethod
class ClusterTemplate(object):
    """Template-method base class for clustering embeddings into centroids.

    Subclasses implement ``init_cluster_algo`` to assign an estimator to
    ``self.clustering_algo``. The estimator must provide ``fit(embeddings)``
    and, after fitting, a ``labels_`` attribute holding one cluster index
    per embedding row. Clients call ``cluster``, which drives the whole
    pipeline: initialise -> train -> (optionally) normalise -> save.

    NOTE: ``@abstractmethod`` is only enforced for classes whose metaclass
    is ``ABCMeta``; this class derives from ``object``, so the base class
    can still be instantiated and the abstract method fails only when it
    is called.
    """

    def __init__(self, path):
        """Remember the output directory and create it if it is missing.

        Args:
            path: Directory in which ``save`` writes the centroid files.
        """
        self.path = path
        if not os.path.exists(self.path):
            os.makedirs(self.path, exist_ok=True)
        # Populated by normalize(); stays None until then.
        self.centroids = None
        # Populated by the subclass in init_cluster_algo().
        self.clustering_algo = None

    @abstractmethod
    def init_cluster_algo(self, num_clusters):
        """Create the estimator and store it in ``self.clustering_algo``.

        Args:
            num_clusters: Number of clusters the estimator should produce.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Abstract method 'init_cluster_algo' not implemented")

    def train(self, embeddings):
        """Fit ``self.clustering_algo`` on ``embeddings``."""
        self.clustering_algo.fit(embeddings)

    def should_normalize(self):
        """Return True when ``cluster`` should compute normalised centroids."""
        return True

    def normalize(self, embeddings, num_clusters):
        """Compute unit-length mean centroids and store them in ``self.centroids``.

        Each centroid is the mean of the embeddings assigned to that cluster
        (according to ``self.clustering_algo.labels_``), scaled to unit L2
        norm. A cluster that received no members, or whose mean is the zero
        vector, is left as an all-zero row instead of producing NaN values.

        Args:
            embeddings: 2-D array-like; rows are indexed by the same positions
                as ``labels_``.
            num_clusters: Number of centroid rows to produce; cluster labels
                are expected to lie in ``range(num_clusters)``.
        """
        embeddings = np.asarray(embeddings)
        # Extract the assigned cluster labels
        labels = np.asarray(self.clustering_algo.labels_)

        # Generate centroids using the embeddings and assigned cluster labels.
        # Preallocating avoids re-copying the array once per cluster.
        dtype = np.result_type(np.float32, embeddings.dtype)
        data = np.zeros((num_clusters, embeddings.shape[1]), dtype=dtype)
        for i in range(num_clusters):
            members = labels == i
            # Skip empty clusters: a mean over zero rows would be 0/0 -> NaN.
            if members.any():
                data[i] = embeddings[members].mean(axis=0)

        # Normalize each centroid to unit length; zero rows stay zero.
        norms = np.linalg.norm(data, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        self.centroids = data / norms

    def save(self, cluster_name):
        """Write ``self.centroids`` with ``np.save`` to ``<path>/<cluster_name>``.

        ``np.save`` appends the ``.npy`` extension when the file name does
        not already end with it.
        """
        np.save(os.path.join(self.path, cluster_name), self.centroids)

    # Final method that no sub class must override. Should be invoked directly from the client
    def cluster(self, features, cluster_name, niter=20, num_clusters=100):
        """Run the full pipeline: init, train, optional normalise, save.

        Args:
            features: 2-D array-like of embeddings to cluster.
            cluster_name: File name (inside ``self.path``) for the saved
                centroids.
            niter: Accepted for interface compatibility; this template does
                not use it.
            num_clusters: Number of clusters to request from the estimator.
        """
        self.init_cluster_algo(num_clusters)
        self.train(features)
        if self.should_normalize():
            self.normalize(features, num_clusters)
        self.save(cluster_name)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement