Advertisement
Guest User

Buggy optimal transport code

a guest
May 7th, 2012
735
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
MatLab 4.80 KB | None | 0 0
  function  [u dt cur energy] = match_pdfs_2d_haker( pdf1, pdf2)
% MATCH_PDFS_2D_HAKER  Gradient-descent flow towards the L2 optimal transport
% map between two 2D densities (Haker-Zhu-Tannenbaum-Angenent scheme).
%
%   [u, dt, cur, energy] = match_pdfs_2d_haker(pdf1, pdf2)
%
%   pdf1, pdf2 : M-by-N arrays of the same size. Both are rescaled to unit
%                mass (as measured by cumIntegrate) before use. pdf1 is
%                divided by inside the flow, so it should be strictly
%                positive; a zero entry gives an infinite speed and dt = 0.
%   u          : M-by-N-by-2 map. u(:,:,1) holds the x (column) coordinate
%                and u(:,:,2) the y (row) coordinate in pdf2 that each
%                pixel of the pdf1 grid is sent to.
%   dt         : 1-by-nIter time steps used.
%   cur        : 1-by-nIter sum of curl(u) at each iteration.
%   energy     : 1-by-nIter transport cost sum(|u - id|^2 .* pdf1).
%
%   Uses cumIntegrate, find_warp and gradientAccurate (this file),
%   poicalc (PDE Toolbox) and curl.

nIter = 3000;   % fixed number of flow iterations

% normalize pdfs to unit mass
mass1 = cumIntegrate(pdf1(:));
mass2 = cumIntegrate(pdf2(:));
pdf1 = pdf1/mass1(end);
pdf2 = pdf2/mass2(end);

[M N] = size(pdf1);
[X,Y] = meshgrid(1:N, 1:M);   % X(i,j) = j (x / column), Y(i,j) = i (y / row)

%%%%% initial map u^0 = (a(x), b(x,y)) %%%%%
% a : 1D matching of the x-marginals (masses of the columns)
proj1 = zeros(1,N);
proj2 = zeros(1,N);
for i=1:N
    col1 = cumIntegrate(pdf1(:,i));
    col2 = cumIntegrate(pdf2(:,i));
    proj1(i) = col1(end);
    proj2(i) = col2(end);
end
a = find_warp(proj2, proj1);   % a(j): x-coordinate in pdf2 matched to column j
ga = gradientAccurate(a);      % da/dx

% b : per-column 1D matching in y, against pdf2 resampled at x = a(j)
[AX, AY] = meshgrid(a, 1:M);   % AX(i,j) = a(j), AY(i,j) = i
interpPDF2 = interp2(X,Y,pdf2, AX, AY, 'spline',0.);
b=zeros(M,N);
for i=1:N
    % the x-warp stretches mass by a'(x), so scale the column by ga(i)
    b(:,i) = find_warp( interpPDF2(:, i).*ga(i), pdf1(:,i));
end

u = cat(3, repmat(a(:).', [M 1]), b);   % u^0 = (a, b)

%%%%% debug : residuals of the initial map %%%%%
ta = interp1(1:N, proj2, a, 'spline',0.).*abs(ga);
fprintf('error 1D:%f\n',norm(proj1(:)-ta(:))/norm(proj1));

[~, dbdy] = gradientAccurate(b);
interpResult = interp2(X,Y,pdf2, u(:,:,1), u(:,:,2), 'spline',0.);
% det(Du^0) = a'(x) * db/dy. a'(x) varies along x, i.e. across columns,
% so it is replicated down the M rows to give an M-by-N factor.
img=interpResult.*repmat(ga(:).', M, 1).*dbdy;
fprintf('error 2D:%f\n',norm(img(:)-pdf1(:))/norm(pdf1(:)));

%%%%% gradient-descent flow %%%%%
dt = zeros(1, nIter);
cur = zeros(1, nIter);
energy = zeros(1, nIter);
invPdf1 = repmat(1./pdf1, [1 1 2]);   % 1/pdf1 for both components

for t=1:nIter
    %%%% v = 1./pdf1 * grad^orthog Laplace^{-1} div(u^orthog) %%%%
    % (orthog = 90 degree rotation). The derivatives of u are computed once
    % and reused for the upwind boundary values and the debug determinant.
    [dudx dudy] = gradientAccurate(u);
    divorthog = -dudy(:,:,1)+dudx(:,:,2);   % div(u^orthog)
    % NOTE(review): poicalc (PDE Toolbox) is documented as solving
    % -div(grad f) = rhs on the interior of a rectangular grid; the sign of
    % divorthog was presumably chosen against that convention -- confirm
    % before swapping in another Poisson solver.
    f = reshape(poicalc(divorthog(:),1,1,M,N), M, N);
    [dfdx dfdy] = gradientAccurate(f);
    v = invPdf1.*cat(3, -dfdy, dfdx);   % grad^orthog f = (-df/dy, df/dx)

    % du/dt = Du * v : the Jacobian MATRIX of u applied to v (not det(Du)
    % times v). This is an advection term, discretised with upwinding.
    Duv = upwind_jacobian_times(u, dudx, dudy, v);

    % explicit 2D upwind advection on a unit grid: dt*(|v1|+|v2|) <= 1
    speed = abs(v(:,:,1)) + abs(v(:,:,2));
    dt(t) = 0.9/max(speed(:));

    %%%%% debug image : pdf2(u).*|det Du| should stay equal to pdf1
    detDu = abs(dudx(:,:,1).*dudy(:,:,2)-dudx(:,:,2).*dudy(:,:,1));
    interpResult = interp2(X,Y,pdf2, u(:,:,1), u(:,:,2), 'spline', 0.);
    img=interpResult.*detDu;
    figure(1);
    image(img*N*N*65); colormap('gray');
    drawnow;   % otherwise the figure only refreshes after the loop

    e = ((u(:,:,1)-X).^2+(u(:,:,2)-Y).^2).*pdf1;   % |u - id|^2 * pdf1
    energy(t) = sum(e(:));
    cur(t) = sum(sum(curl(u(:,:,1),u(:,:,2))));
    fprintf('iter:%u\tdt:%e\tcurl:%f\tenergy:%f\n', t, dt(t), cur(t), energy(t));

    u = u+dt(t).*Duv;
end

function Duv = upwind_jacobian_times(u, dudx, dudy, v)
% Upwinded Jacobian-vector product (Du)*v = v1*du/dx + v2*du/dy, computed
% for both pages of u (M-by-N-by-2). For u_t = v.grad(u), a positive v1
% takes the forward x-difference and a non-positive v1 the backward one
% (same for v2 in y). Second-order one-sided stencils are used on the
% interior (indices 3:end-2). The two outer rows/columns keep the centred
% derivatives dudx/dudy supplied by the caller.
dudxm = dudx; dudxp = dudx;
dudym = dudy; dudyp = dudy;
dudxm(:, 3:end-2, :) = (3*u(:, 3:end-2, :) - 4*u(:, 2:end-3, :) + u(:, 1:end-4, :))*0.5;
dudxp(:, 3:end-2, :) = (-3*u(:, 3:end-2, :) + 4*u(:, 4:end-1, :) - u(:, 5:end, :))*0.5;
dudym(3:end-2, :, :) = (3*u(3:end-2, :, :) - 4*u(2:end-3, :, :) + u(1:end-4, :, :))*0.5;
dudyp(3:end-2, :, :) = (-3*u(3:end-2, :, :) + 4*u(4:end-1, :, :) - u(5:end, :, :))*0.5;

v1 = repmat(v(:,:,1), [1 1 2]);
v2 = repmat(v(:,:,2), [1 1 2]);
ux = (v1 > 0).*dudxp + (v1 <= 0).*dudxm;
uy = (v2 > 0).*dudyp + (v2 <= 0).*dudym;
Duv = v1.*ux + v2.*uy;
  103.  
  104.  
  105.  
  106. function y = find_warp(pdf1, pdf2) %1D matching
  107. cdf1 = cumIntegrate(pdf1);
  108. cdf2 = cumIntegrate(pdf2);
  109. cdf1 = cdf1/cdf1(end);
  110. cdf2 = cdf2/cdf2(end);
  111.  
  112. N = length(cdf1);
  113. y = zeros(N, 1);
  114.  
  115. % cursor = 2;
  116. % for i=1:N
  117. %     desiredF = cdf2(i);
  118. %    
  119. %     while (cursor<N && cdf1(cursor)<desiredF)
  120. %      cursor = cursor+1;
  121. %     end;    
  122. %        
  123. %     alpha = min(1, max(0,(desiredF - cdf1(cursor-1))/(cdf1(cursor)-cdf1(cursor-1))));
  124. %     y(i) = ((cursor-1)*(1-alpha) + alpha*cursor );
  125. % end;
  126.  
  127. % the one below is shorter and might be more accurate
  128.    [ucdf1, uxx] = unique(cdf1, 'first');
  129.    y = transpose(interp1(ucdf1,uxx,cdf2,'spline',N));
  130.    
  131.    
  132. function y = cumIntegrate(f)  % generic function to approximate the cumulative integral
  133.  y= cumtrapz(f); % or y=cumsum, or y=intgrad1 with the library
  134.  
  135. function varargout=gradientAccurate(f)
  136. if (length(f)==length(f(:)))
  137.     [varargout{1}] = gradient2(f);  %or my 5th order gradient2(f) code
  138. else
  139.     [varargout{1} varargout{2}] = gradient2(f);  %or my 5th order gradient2(f) code
  140. end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement