Advertisement
Guest User

Untitled

a guest
Apr 10th, 2020
193
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.35 KB | None | 0 0
  1. class HistogramClustering(skl.base.BaseEstimator, skl.base.TransformerMixin):
  2.     """Template class for HistogramClustering (HC)
  3.    
  4.    Attributes:
  5.        centroids (np.ndarray): Array of centroid distributions p(y|c) with shape (n_clusters, n_bins).
  6.        
  7.    Parameters:
  8.        n_clusters (int): Number of clusters (textures).
  9.        n_bins (int): Number of bins used to discretize the range of pixel values found in input image X.
  10.        window_size (int): Size of the window used to compute the local histograms for each pixel.
  11.                           Should be an odd number larger or equal to 3.
  12.        random_state (int): Random seed.
  13.        estimation (str): Whether to use Maximum a Posteriori ("MAP") or
  14.                          Deterministic Annealing ("DA") estimation.
  15.    """
  16.    
  17.     def __init__(self, n_clusters=10, n_bins=64, window_size=7, random_state=42, estimation="MAP"):
  18.         self.n_clusters = n_clusters
  19.         self.n_bins =n_bins
  20.         self.window_size = window_size
  21.         self.random_state = random_state
  22.         self.estimation = estimation
  23.         # Add more parameters, if necessary.
  24.    
  25.     def fit(self, X):
  26.         """Compute HC for input image X
  27.        
  28.        Compute centroids.        
  29.        
  30.        Args:
  31.            X (np.ndarray): Input array with shape (height, width)
  32.        
  33.        Returns:
  34.            self
  35.        """
  36.  
  37.         np.histogram(X, bins=np.arange(self.n_bins))
  38.  
  39.         # First calculate the bins
  40.        
  41.         normalized_X = X.copy()
  42.  
  43.         normalized_X *= self.n_bins
  44.         # normalized_X = np.ceil(normalized_X)
  45.  
  46.         N = np.prod(X.shape)
  47.  
  48.         n = np.zeros((N, self.n_bins))
  49.  
  50.         indices = [(x, y) for x, y in np.ndindex(X.shape)]
  51.        
  52.         pad = self.window_size // 2
  53.         indices = np.add(indices, pad)
  54.         Xmax = normalized_X.max()
  55.         padded_X = np.pad(normalized_X, pad, 'constant', constant_values=normalized_X.max() * 10)
  56.  
  57.  
  58.         i = 0
  59.         for idx in indices:
  60.             histogram = np.histogram(padded_X[idx[0]-pad:idx[0]+pad, idx[1]-pad:idx[1]+pad], bins=self.n_bins, range=(padded_X.min(), Xmax))
  61.             n[i, ] = histogram[0]
  62.             i += 1
  63.  
  64.         if self.estimation == "MAP":
  65.            
  66.             # These are constant w.r.t the iterative algorithm
  67.             n_x = np.sum(n, axis=1).reshape(-1, 1)
  68.  
  69.             # p_x = n_x / np.sum(n_x)
  70.  
  71.             p_y_x = n / n_x
  72.  
  73.             p_y_c = np.random.multivariate_normal(np.zeros(self.n_bins * self.n_clusters), np.eye(self.n_bins * self.n_clusters)).reshape(self.n_bins, self.n_clusters)
  74.             p_y_c /= np.sum(p_y_c, axis=1).reshape(-1, 1)
  75.  
  76.             c_x = np.random.randint(0, self.n_clusters, size=(N, 1))
  77.             c_x_old = c_x
  78.  
  79.             p_y_c_old = p_y_c
  80.  
  81.             iter = 0
  82.             while True:
  83.                 # estimate p(y|c):w
  84.                 mask = np.zeros((N, self.n_clusters))
  85.                 mask[np.arange(N).reshape(-1, 1), c_x_old] = 1
  86.                 for c in range(self.n_clusters):
  87.                     inner_sum = n_x / np.sum(n_x * mask[:, c].reshape(-1, 1), axis=0, keepdims=True)
  88.                     row = np.sum( inner_sum * p_y_x * mask[:, c].reshape(-1, 1), axis=0)
  89.  
  90.                     p_y_c[:, c] = row
  91.  
  92.                 cluster_assignments = np.zeros((N, self.n_clusters))
  93.                 for c in range(self.n_clusters):
  94.                     # m = 1e16
  95.                     cluster_assignments[:, c] = -np.sum(p_y_x * np.log(p_y_c_old[:, c] + np.finfo(float).eps), axis=1)
  96.  
  97.                 c_x = np.argmin(cluster_assignments, axis=1).reshape(-1, 1)
  98.  
  99.                 diff = np.sum(p_y_c - p_y_c_old)
  100.                 if iter > 10:
  101.                     break
  102.                 c_x_old = c_x
  103.                 p_y_c_old = p_y_c
  104.                 iter += 1
  105.  
  106.                 print('\rFinished iter {}, diff={}'.format(iter, diff), end='', flush=True)
  107.  
  108.  
  109.             self.p_y_x = p_y_x
  110.             self.p_y_c = p_y_c
  111.             self.c_x = c_x
  112.  
  113.             self.centroids = p_y_c
  114.            
  115.             # Code for Maximum a Posteriori estimation
  116.        
  117.         elif self.estimation == "DA":
  118.             raise NotImplementedError()
  119.            
  120.             # Code for Deterministic Annealing estimation
  121.        
  122.         return self
  123.    
  124.     def predict(self, X):
  125.         """Predict cluster assignments for each pixel in image X.
  126.        
  127.        Args:
  128.            X (np.ndarray): Input array with shape (height, width)
  129.            
  130.        Returns:
  131.            C (np.ndarray): Assignment map (height, width)
  132.        """
  133.         check_is_fitted(self, ["centroids"])
  134.        
  135.         # Your code goes here
  136.  
  137.         return C
  138.    
  139.     def generate(self, C):
  140.         """Generate a sample image X from a texture label map C.
  141.        
  142.        The entries of C are integers from the set {1,...,n_clusters}. They represent the texture labels
  143.        of each pixel. Given the texture labels, a sample image X is generated by sampling
  144.        the value of each pixel from the fitted p(y|c).
  145.        
  146.        Args:
  147.            C (np.ndarray): Input array with shape (height, width)
  148.            
  149.        Returns:
  150.            X (np.ndarray): Sample image (height, width)
  151.        """
  152.         check_is_fitted(self, ["centroids"])
  153.        
  154.         # Your code goes here
  155.        
  156.         return X
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement