Guest User

Untitled

a guest
Sep 21st, 2018
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 9.83 KB | None | 0 0
  1. package com.example.javagmm;
  2.  
  3. import java.io.FileInputStream;
  4. import java.io.FileOutputStream;
  5. import java.io.IOException;
  6. import java.io.ObjectInputStream;
  7. import java.io.ObjectOutputStream;
  8. import java.lang.ref.SoftReference;
  9. import java.util.Random;
  10.  
  11.  
  12.  
  13. /**
  14. * <b>Gaussian Mixture Model</b>
  15. *
  16. * <p>Description</p>
  17. * This class implements in combination with <code>GaussianComponent</code> a
  18. * gaussian mixture model. To create a gaussian mixture model you have to
  19. * specify means, covariances and the weight for each of the gaussian components.
  20. * <code>GaussianMixture</code> and GaussianComponent do not only model a GMM,
  21. * but also support training GMMs using the EM algorithm.<br>
  22. * <br>
  23. * One noteable aspect regarding the implementation is the fact that if
  24. * the covariance matrix of any of the components of this GMM gets singular
  25. * during the training process a <code>CovarianceSingularityException</code> is
  26. * thrown. The <code>CovarianceSingularityException</code> contains a reduced
  27. * <code>PointList</code>. All points belonging to the singular component have
  28. * been removed. So after the reduction one can try to rerun the training
  29. * algorithem with the reduced <code>PointList</code>.<br>
  30. * <br>
  31. * Another aspect of the design of this class was influenced by the limited
  32. * memory on real world computers. To improve performace a buffer to store some
  33. * estimations is used. This buffer is <i>static</i> to reduce garbage
  34. * collection time and all training processes are synchronized on this buffer.
  35. * Consequently one can only train one GMM instance at a time.<br>
  36. * <br>
  37. * <b>New in version 1.1:</b><br>
  38. * - The cholesky decomposition is used to speed up computations.
  39. *
  40. * @see comirva.audio.util.gmm.GaussianComponent
  41. * @author Klaus Seyerlehner
  42. * @version 1.1
  43. */
  44. public final class GaussianMixture implements java.io.Serializable
  45. {
  46. /**
  47. *
  48. */
  49. private static final long serialVersionUID = -3450643365494865923L;
  50.  
  51. private int dimension = 0; //dimension of the gmm
  52. private GaussianComponent[] components = new GaussianComponent[0]; //gaussian components
  53.  
  54. private static double[][] p_ij = new double[1][1]; //hard reference to the buffer of current estimate
  55. private static SoftReference<double[][]> p_ij_SoftRef = new SoftReference<double[][]>(p_ij); //soft reference to the buffer of current estimates //defines the maximum number of training iterations
  56. private static Random rnd = new Random(); //a random number generator
  57.  
  58.  
  59. /**
  60. * This constructor creates a GMM and checks the parameters for plausibility.
  61. * The weights, means and covariances of every component are passed as arrays
  62. * to the constructor. The i-th component therefore is completely defined by
  63. * the i-th entries within these arrays.
  64. *
  65. * @param componentWeights double[] specifies the components weights
  66. * @param means Matrix[] specifies the components mean vectors
  67. * @param covariances Matrix[] specifies the components covariance matrices
  68. *
  69. * @throws IllegalArgumentException if any invalid parameter settings are
  70. * detected while checking them
  71. */
  72. public GaussianMixture(double[] componentWeights, Matrix[] means, Matrix[] covariances) throws IllegalArgumentException
  73. {
  74. //check if all parameters are valid
  75. if(componentWeights.length != means.length || means.length != covariances.length || componentWeights.length < 1)
  76. throw new IllegalArgumentException("all arrays must have the same length with size greater than 0;");
  77.  
  78. //create component array
  79. components = new GaussianComponent[componentWeights.length];
  80.  
  81. //check and create the components
  82. double sum = 0;
  83. for(int i = 0; i < components.length; i++)
  84. {
  85. if(means[i] == null || covariances[i] == null)
  86. throw new IllegalArgumentException("all mean and covarince matrices must not be null values;");
  87.  
  88. sum += componentWeights[i];
  89.  
  90. components[i] = new GaussianComponent(componentWeights[i], means[i], covariances[i]);
  91. }
  92.  
  93. //check if the component weights are set correctly
  94. if( sum < 0.99 || sum > 1.01)
  95. throw new IllegalArgumentException("the sum over all component weights must be in the interval [0.99, 1.01];");
  96.  
  97. //set dimension
  98. this.dimension = components[0].getDimension();
  99.  
  100. //check if all the components have the same dimensions
  101. for(int i = 0; i < components.length; i++)
  102. if(components[i].getDimension() != dimension)
  103. throw new IllegalArgumentException("the dimensions of all components must be the same;");
  104.  
  105. }
  106.  
  107. /**
  108. * This constructor creates a GMM and checks the components for compatibility.
  109. * The components themselfs have been checked during their construction.
  110. *
  111. * @param components GaussianComponent[] an array of gaussian components
  112. *
  113. * @throws IllegalArgumentException if the passed components are not
  114. * compatible
  115. */
  116. public GaussianMixture(GaussianComponent[] components) throws IllegalArgumentException
  117. {
  118. if(components == null)
  119. throw new IllegalArgumentException("the component array must not be null;");
  120.  
  121. //check the components
  122. double sum = 0;
  123. for(int i = 0; i < components.length; i++)
  124. {
  125. if(components[i] == null)
  126. throw new IllegalArgumentException("all components in the array must not be null;");
  127.  
  128. sum += components[i].getComponentWeight();
  129. }
  130.  
  131. //check if the component weights are set correctly
  132. if( sum < 0.99 || sum > 1.01)
  133. throw new IllegalArgumentException("the sum over all component weights must be in the interval [0.99, 1.01];");
  134.  
  135. this.components = components;
  136. this.dimension = components[0].getDimension();
  137.  
  138. //check if all the components have the same dimensions
  139. for(int i = 0; i < components.length; i++)
  140. if(components[i].getDimension() != dimension)
  141. throw new IllegalArgumentException("the dimensions of all components must be the same;");
  142. }
  143.  
  144.  
  145.  
  146. /**
  147. * Returns the log likelihood of the points stored in the pointlist under the
  148. * assumption the these points where sample from this GMM.<br>
  149. * <br>
  150. * [SUM over all j: log (SUM over all i:(p(x_j | C = i) * P(C = i)))]
  151. *
  152. * @param points PointList list of sample points to estimate the log
  153. * likelihood of
  154. * @return double the log likelihood of drawing these samples from this gmm
  155. */
  156. public double getLogLikelihood(PointList points)
  157. {
  158. double p = 0;
  159. for (int j = 0; j < points.size(); j++)
  160. p += Math.log(getProbability((Matrix) points.get(j)));
  161. return p;
  162. }
  163.  
  164.  
  165. /**
  166. * Returns the probability of a single sample point under the assumption that
  167. * it was draw from the distribution represented by this GMM.<br>
  168. * <br>
  169. * [SUM over all i:(p(x | C = i) * P(C = i))]
  170. *
  171. * @param x Matrix a sample point
  172. * @return double the probability of the given sample
  173. */
  174. public double getProbability(Matrix x)
  175. {
  176. double p = 0;
  177.  
  178. for(int i = 0; i < components.length; i++)
  179. p += components[i].getWeightedSampleProbability(x);
  180.  
  181. return p;
  182. }
  183.  
  184.  
  185. /**
  186. * Returns the number of dimensions of the GMM.
  187. *
  188. * @return int number of dimmensions
  189. */
  190. public int getDimension()
  191. {
  192. return dimension;
  193. }
  194.  
  195.  
  196. /**
  197. * Prints some information about this gaussian component.
  198. * This is for debugging purpose only.
  199. */
  200. public void print()
  201. {
  202. for(int i = 0; i < components.length; i++)
  203. {
  204. System.out.println("Component " + i + ":");
  205. components[i].print();
  206. }
  207. }
  208.  
  209.  
  210. /**
  211. * For testing purpose only.
  212. *
  213. * @param numberOfComponent int the number of the component
  214. * @return Matrix the mean vector
  215. */
  216. public Matrix getMean(int numberOfComponent)
  217. {
  218. return components[numberOfComponent].getMean();
  219. }
  220.  
  221.  
  222. /**
  223. * This method returns a reference to a buffer for storing estimates of the
  224. * sample points. The buffer will be reused if possible or reallocated, if
  225. * it is too small or if the garbage collector already captured the buffer.
  226. *
  227. * @param nrComponents int the number of components of the gmm to allocate the
  228. * buffer for
  229. * @param nrSamplePoints int the number of sample points of the gmm to
  230. * allocate the buffer for
  231. */
  232. protected static void getBuffer(int nrComponents, int nrSamplePoints)
  233. {
  234. //get the buffer from the soft ref => now hard ref
  235. p_ij = p_ij_SoftRef.get();
  236.  
  237. if(p_ij == null)
  238. {
  239. //reallocate since gc collected the buffer
  240. p_ij = new double[nrComponents][2*nrSamplePoints];
  241. p_ij_SoftRef = new SoftReference<double[][]>(p_ij);
  242. }
  243.  
  244. //check if buffer is too small
  245. if (p_ij[0].length >= nrSamplePoints && p_ij.length >= nrComponents)
  246. return;
  247.  
  248. //to prevent gc runs take double of the current buffer size
  249. if (p_ij[0].length < nrSamplePoints)
  250. nrSamplePoints += nrSamplePoints;
  251.  
  252. //reallocate since buffer was too small
  253. p_ij = new double[nrComponents][nrSamplePoints];
  254. p_ij_SoftRef = new SoftReference<double[][]>(p_ij);
  255.  
  256. //run gc to collect old buffer
  257. System.gc();
  258. }
  259.  
  260.  
  261. /**
  262. * Reads a GaussianMixture object by deserializing it from disk
  263. * @author Florian Schulze
  264. * @param path Path of the serialized GMM file
  265. * @return The deserialized GMM
  266. */
  267. public static GaussianMixture readGMM(String path) {
  268. GaussianMixture gmm = null;
  269. ObjectInputStream ois;
  270. try {
  271. ois = new ObjectInputStream(new FileInputStream(path));
  272. gmm = (GaussianMixture) ois.readObject();
  273. ois.close();
  274. } catch (IOException e) {
  275. e.printStackTrace();
  276. } catch (ClassNotFoundException e) {
  277. e.printStackTrace();
  278. }
  279. return gmm;
  280. }
  281.  
  282.  
  283. /**
  284. * Stores a GaussianMixture object on the hard disk by serializing it
  285. * @param path Where to store the file
  286. */
  287. public void writeGMM(String path) {
  288. try {
  289. ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(path));
  290. oos.writeObject(this);
  291. oos.flush();
  292. oos.close();
  293. } catch (IOException e) {
  294. e.printStackTrace();
  295. }
  296. }
  297. }
Add Comment
Please, Sign In to add comment