Emania

Untitled

Dec 13th, 2016
125
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.24 KB | None | 0 0
  function [Mean,Std,PG] = EM(x, nr_groups, labels)
  % EM  Iterative soft-assignment fit of a 1-dimensional Gaussian mixture.
  %
  %   [Mean,Std,PG] = EM(x, nr_groups)
  %   [Mean,Std,PG] = EM(x, nr_groups, labels)
  %
  % INPUT:
  %   x ........... data vector (row or column; it is reshaped to a column)
  %   nr_groups ... number of Gaussians
  %   labels ...... optional class labels, one entry per data point.
  %                 0 marks an unlabelled point; a value j pins the point to
  %                 Gaussian j on every iteration.
  %
  % OUTPUT:
  %   Mean ........ Mean(i) is the mean of the i-th Gaussian
  %   Std ......... Std(i) is the standard deviation of the i-th Gaussian
  %   PG .......... PG(i) is the average soft assignment to the i-th Gaussian
  %                 (column sums of the assignment matrix divided by nr_pts)
  %
  % The iteration stops when no soft assignment changes by more than
  % conv_tol between two consecutive iterations, or after max_nr_it
  % iterations, whichever comes first.
  %
  % NOTE(review): the assignment step uses the plain component densities and
  % does not multiply them by the current mixture weights. Confirm that this
  % equal-weight variant is intended before relying on PG as a true prior.

  max_nr_it = 1000;    % hard cap on the number of iterations
  conv_tol  = 1e-10;   % largest per-entry change in P still counted as "converged"
  min_std   = eps;     % floor for a collapsed component; sigma = 0 is outside
                       % normpdf's documented domain and would propagate NaN

  % Work on columns so that x.*P(:,k) is element-wise for any input orientation.
  x = x(:);
  nr_pts = numel(x);

  has_labels = (nargin == 3);
  if has_labels
      labels = labels(:);
  end

  % Initial soft assignment: one row per point, one column per Gaussian.
  if nr_groups > 1
      if has_labels
          P = initialize_em_range_based(x, nr_groups, labels); % See end of this file
      else
          P = initialize_em_range_based(x, nr_groups);
      end
  else
      % only one group: trivial assignment
      P = ones(nr_pts,1);
  end

  Mean = zeros(nr_groups,1);
  Std  = ones(nr_groups,1);

  green_light = 1;
  nr_it = 0;

  while (green_light == 1) && (nr_it < max_nr_it)
      nr_it = nr_it + 1;

      P_old = P;
      P_new = zeros(size(P));

      for k = 1:nr_groups
          PP = P(:,k);
          weight = sum(PP);
          if weight ~= 0
              % Weighted mean and standard deviation of the points in group k.
              mean_grp = sum(x.*PP)/weight;
              var_grp  = sum(((x - mean_grp).^2).*PP)/weight;
              std_grp  = max(sqrt(var_grp), min_std);
          else
              % No point is assigned to this group: fall back to N(0,1).
              mean_grp = 0;
              std_grp  = 1;
          end

          Mean(k) = mean_grp;
          Std(k)  = std_grp;
          P_new(:,k) = normpdf(x, mean_grp, std_grp);
      end

      P = P_new;

      % Labelled points keep a hard (one-hot) assignment to their class.
      if has_labels
          have = labels ~= 0;
          for j = 1:nr_groups
              P(have,j) = labels(have) == j;
          end
      end

      % Renormalize every row to sum to 1. Rows whose total is numerically
      % zero would divide by zero, so they get a uniform distribution instead.
      P_sum  = sum(P,2);
      u_zero = P_sum < 10^(-200);
      P_sum(u_zero) = 1;
      P = P./(P_sum*ones(1,nr_groups));
      P(u_zero,:) = 1/nr_groups;

      % Convergence test: stop once the assignments no longer move.
      if max(abs(P(:) - P_old(:))) < conv_tol
          green_light = 0;
      end
  end

  %
  % PACKAGE RESULTS FOR EXPORT
  %===========================

  % P is a matrix of size [nr_pts,nr_groups] such that each row contains the
  % soft assignment of the corresponding point to the different Gaussian
  % components. Summing each column and dividing by the number of points
  % gives the proportional contribution of each Gaussian.

  PG = sum(P,1)'/nr_pts;
  101.  
  102. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
  103.  
  104.  
  function P_0 = initialize_em_range_based(x,nr_groups,labels)
  % INITIALIZE_EM_RANGE_BASED  Initial soft assignment for the EM iteration.
  %
  %   P_0 = initialize_em_range_based(x, nr_groups)
  %   P_0 = initialize_em_range_based(x, nr_groups, labels)
  %
  % INPUT:
  %   x ........... data vector
  %   nr_groups ... number of Gaussians
  %   labels ...... optional class labels, one entry per data point
  %                 (0 = unlabelled)
  %
  % OUTPUT:
  %   P_0 ......... [nr_pts,nr_groups] matrix whose rows sum to 1
  %
  % Without labels, every point is scored against nr_groups Gaussians whose
  % centres are spread evenly from min(x) to max(x). With labels, unlabelled
  % points are scored the same way, while labelled points are scored with
  % the parameters returned by estGauss. A point whose label i lies in
  % 1..nr_groups finally receives a one-hot row for group i.
  %
  % Rows that cannot be normalized (zero or non-finite sum, e.g. when all
  % data values coincide) fall back to the uniform row 1/nr_groups.

  has_labels = (nargin == 3);

  if has_labels
      % estGauss is defined outside this file, so it receives the arguments
      % exactly as passed in.
      [Means, Stds, PGs] = estGauss(x, nr_groups, labels);
      labels = labels(:);
  end

  x = x(:);
  nr_pts = numel(x);
  P_raw = zeros(nr_pts,nr_groups);

  % Evenly spaced centres over the data range, sharing one common width.
  xrange = max(x) - min(x);
  dx = xrange/(nr_groups-1);
  ss = 0.5*xrange/(2*(nr_groups-1));
  xc = min(x) + dx*(0:nr_groups-1);

  if has_labels
      % Labelled points: weighted component densities from estGauss.
      % Row normalization below turns them into posteriors.
      have = labels ~= 0;
      for j = 1:nr_groups
          P_raw(have,j) = PGs(j)*normpdf(x(have), Means(j), Stds(j));
      end
      use_range = ~have;
  else
      use_range = true(nr_pts,1);
  end

  % Range-based scores for every point not handled above.
  for j = 1:nr_groups
      P_raw(use_range,j) = normpdf(x(use_range), xc(j), ss);
  end

  % make sure each row sums to 1; degenerate rows become uniform
  row_sum = sum(P_raw,2);
  bad = ~(row_sum > 0) | ~isfinite(row_sum);
  row_sum(bad) = 1;
  P_0 = P_raw./(row_sum*ones(1,nr_groups));
  P_0(bad,:) = 1/nr_groups;

  % Points with a valid class label are pinned to that class.
  if has_labels
      for i = 1:nr_groups
          indexes = labels == i;
          P_0(indexes,:) = 0;
          P_0(indexes,i) = 1;
      end
  end
Advertisement
Add Comment
Please, Sign In to add comment