SHARE
TWEET

Buggy optimal transport code

a guest May 7th, 2012 293 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. function  [u dt cur energy] = match_pdfs_2d_haker( pdf1, pdf2)
  2.  
  3. % normalize pdfs
  4. cumsum1 = cumIntegrate(pdf1(:));
  5. cumsum2 = cumIntegrate(pdf2(:));
  6. pdf1 = pdf1/cumsum1(end);
  7. pdf2 = pdf2/cumsum2(end);
  8.  
  9. [M N] = size(pdf1);
  10.  
  11. % find an initial map, by doing 1D matchings in x and then in y.
  12. proj1 = zeros(1,N);
  13. proj2 = zeros(1,N);
  14. for i=1:N
  15.     cumsum1 = cumIntegrate(pdf1(:,i));
  16.     cumsum2 = cumIntegrate(pdf2(:,i));
  17.     proj1(i) = cumsum1(end);
  18.     proj2(i) = cumsum2(end);
  19. end;
  20.  
  21. a = find_warp(proj2, proj1);
  22. ga = gradientAccurate(a);
  23.  
  24. [X,Y] = meshgrid(1:N, 1:M);
  25. interpPDF2 = interp2(X,Y,pdf2, a, 1:M, 'spline',0.);
  26. b=zeros(M,N);
  27. for i=1:N
  28.     b(:,i) = find_warp( interpPDF2(:, i).*ga(i), pdf1(:,i));  
  29. end;
  30.  
  31. %u^0 = (a,b)
  32. u=reshape([repmat(a', [M 1]) b], M, N,2);
  33.  
  34.  
  35. %%%%% debug : compute errors %%%%%%%%%%
  36. ta = interp1(1:N, proj2, a, 'spline',0.).*abs(ga);
  37. fprintf('error 1D:%f\n',norm(proj1-ta')/norm(proj1));
  38.  
  39. [dbdx dbdy] = gradientAccurate(b);
  40. interpResult = interp2(X,Y,pdf2, u(:,:,1), u(:,:,2), 'spline',0.);
  41. img=interpResult.*repmat(ga, 1, N).*dbdy;
  42. fprintf('error 2D:%f\n',norm(img(:)-pdf1(:))/norm(pdf1(:)));
  43. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
  44.  
  45. %let's a constant number of iterations
  46. for t=1:3000
  47.      
  48.     %%%%%%%% compute grad^orthog Laplace^{-1} div(u^orthog) %%%%%%%%%
  49.     %%%%%%%% where orthog stands for a 90 degrees rotations,%%%%%%%%%
  50.     %%%%%%%% and Laplace^-1(g) solves Laplacian f = g       %%%%%%%%%
  51.  
  52.     [dudx dudy] = gradientAccurate(u);
  53.     divorthog = -dudy(:,:,1)+dudx(:,:,2); %div(u^orthog)
  54.     f=reshape(poicalc(divorthog(:),1,1,M,N), M, N); %Laplace^{-1} div(u^orthog)
  55.     [dfdx dfdy] = gradientAccurate(f); %grad Laplace^{-1} div(u^orthog)
  56.    
  57.     % 1./pdf1 * grad^orthog Laplace^{-1} div(u^orthog)
  58.     update1 = repmat(1./pdf1, [1 1 2]).*reshape([-dfdy dfdx], size(u));            
  59.  
  60.     %%%%%%%% upwind for Du %%%%%%%%%%%%%%%%%
  61.     [dudx dudy] = gradientAccurate(u);
  62.     dudxm = dudx;
  63.     dudym = dudy;
  64.     dudxp = dudx;
  65.     dudyp = dudy;
  66.     dudxm(:, 3:end-2, :) = (3*u(:, 3:end-2, :) - 4*u(:, 2:end-3, :) + u(:, 1:end-4, :))*0.5;
  67.     dudxp(:, 3:end-2, :) = (-3*u(:, 3:end-2, :) + 4*u(:, 4:end-1, :) - u(:, 5:end, :))*0.5;
  68.     dudym(3:end-2, :, :) = (3*u(3:end-2, :, :) - 4*u(2:end-3, :, :) + u(1:end-4, :, :))*0.5;
  69.     dudyp(3:end-2, :, :) = (-3*u(3:end-2, :, :) + 4*u(4:end-1, :, :) - u(5:end, :, :))*0.5;
  70.      
  71.     %Du = abs(dudx(:,:,1).*dudy(:,:,2)-dudx(:,:,2).*dudy(:,:,1)); %centered
  72.     Dupp = abs(dudxp(:,:,1).*dudyp(:,:,2)-dudxp(:,:,2).*dudyp(:,:,1));
  73.     Dupm = abs(dudxp(:,:,1).*dudym(:,:,2)-dudxp(:,:,2).*dudym(:,:,1));
  74.     Dump = abs(dudxm(:,:,1).*dudyp(:,:,2)-dudxm(:,:,2).*dudyp(:,:,1));
  75.     Dumm = abs(dudxm(:,:,1).*dudym(:,:,2)-dudxm(:,:,2).*dudym(:,:,1));
  76.        
  77.     Du =  (update1(:, :,1)>0).*(update1(:, :,2)>0).*Dupp ...
  78.         + (update1(:, :,1)>0).*(update1(:, :,2)<0).*Dupm ...
  79.         + (update1(:, :,1)<0).*(update1(:, :,2)>0).*Dump ...
  80.         + (update1(:, :,1)<0).*(update1(:, :,2)<0).*Dumm;
  81.    
  82.     dt(t) = 0.9*min(1./abs(update1(:))); %dt according to stability conditions
  83.        
  84.     %%%%% debug image : should remain constant
  85.      interpResult = interp2(X,Y,pdf2, u(:,:,1), u(:,:,2), 'spline', 0.);
  86.      img=interpResult.*Du;  
  87.      figure(1);
  88.      image(img*N*N*65); colormap('gray');
  89.      
  90.     uxmid = u(:,:,1)-repmat(1:N, M, 1); %u-identity
  91.     uymid = u(:,:,2)-repmat(transpose(1:M), 1, N);        
  92.     e = (uxmid.^2+uymid.^2).*pdf1;
  93.     energy(t) = sum(e(:));        
  94.     cur(t) = sum(sum(curl(u(:,:,1),u(:,:,2))));
  95.     fprintf('iter:%u\tdt:%e\tcurl:%f\tenergy:%f\n', t, dt(t), cur(t), energy(t));    
  96.     %%%%%%%%%%%%%%    
  97.        
  98.  
  99.     % update : du/dt = 1./pdf1 * det(Jac(u)) * update1
  100.     u = u+dt(t).*repmat(Du, [1 1 2]).*update1;
  101.    
  102. end;
  103.  
  104.  
  105.  
  106. function y = find_warp(pdf1, pdf2) %1D matching
  107. cdf1 = cumIntegrate(pdf1);
  108. cdf2 = cumIntegrate(pdf2);
  109. cdf1 = cdf1/cdf1(end);
  110. cdf2 = cdf2/cdf2(end);
  111.  
  112. N = length(cdf1);
  113. y = zeros(N, 1);
  114.  
  115. % cursor = 2;
  116. % for i=1:N
  117. %     desiredF = cdf2(i);
  118. %    
  119. %     while (cursor<N && cdf1(cursor)<desiredF)
  120. %      cursor = cursor+1;
  121. %     end;    
  122. %        
  123. %     alpha = min(1, max(0,(desiredF - cdf1(cursor-1))/(cdf1(cursor)-cdf1(cursor-1))));
  124. %     y(i) = ((cursor-1)*(1-alpha) + alpha*cursor );
  125. % end;
  126.  
  127. % the one below is shorter and might be more accurate
  128.    [ucdf1, uxx] = unique(cdf1, 'first');
  129.    y = transpose(interp1(ucdf1,uxx,cdf2,'spline',N));
  130.    
  131.    
  132. function y = cumIntegrate(f)  % generic function to approximate the cumulative integral
  133.  y= cumtrapz(f); % or y=cumsum, or y=intgrad1 with the library
  134.  
  135. function varargout=gradientAccurate(f)
  136. if (length(f)==length(f(:)))
  137.     [varargout{1}] = gradient2(f);  %or my 5th order gradient2(f) code
  138. else
  139.     [varargout{1} varargout{2}] = gradient2(f);  %or my 5th order gradient2(f) code
  140. end
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top