Advertisement
Guest User

Untitled

a guest
May 6th, 2015
202
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
MatLab 1.54 KB | None | 0 0
  function X_pred = svd_bias(X, varargin)
  %SVD_BIAS Fill missing entries of a ratings matrix with a biased matrix-factorization model.
  %   X_pred = SVD_BIAS(X) fits the model
  %       r_hat(u,i) = mu + bu(u) + bi(i) + P(u,:) * Q(i,:)'
  %   to the nonzero entries of X using stochastic gradient descent (one randomly
  %   chosen observed entry per step). It returns a copy of X whose zero entries
  %   are replaced by model predictions. Nonzero entries are returned unchanged.
  %   Zero is treated as "missing", so an observed value of 0 cannot be represented.
  %
  %   X is m-by-n, with rows = users and columns = items. It may be full or sparse.
  %
  %   X_pred = SVD_BIAS(X, Name, Value, ...) overrides the defaults:
  %     'MaxIters'      number of SGD steps                              [100000]
  %     'Lambda'        L2 regularization weight on biases and factors   [0]
  %     'Factors'       number of latent factors per user/item (f)       [1]
  %     'LearningRate'  SGD step size (gamma)                            [0.05]
  %     'InitScale'     biases/factors start as InitScale*rand(...)      [1]
  %     'Verbose'       print progress every 1000 iterations             [true]
  %
  %   Results are random (rand/randi are used); seed the RNG for reproducibility.
  %   Errors with id 'svd_bias:noObservations' if X has no nonzero entries.

      isNonNegScalar = @(v) isnumeric(v) && isscalar(v) && isreal(v) && v >= 0;
      isCount = @(v) isNonNegScalar(v) && v == fix(v);
      p = inputParser;
      p.FunctionName = 'svd_bias';
      addParameter(p, 'MaxIters', 100000, isCount);
      addParameter(p, 'Lambda', 0, isNonNegScalar);
      addParameter(p, 'Factors', 1, isCount);
      addParameter(p, 'LearningRate', 0.05, isNonNegScalar);
      addParameter(p, 'InitScale', 1, isNonNegScalar);
      addParameter(p, 'Verbose', true, @(v) isscalar(v) && (islogical(v) || isnumeric(v)));
      parse(p, varargin{:});
      opts = p.Results;

      m = size(X, 1);   % users
      n = size(X, 2);   % items

      % Observed training triples. find() returns row vectors when X is a row
      % vector, so force columns to keep one observation per element.
      [obsRow, obsCol, obsVal] = find(X);
      obsRow = obsRow(:);
      obsCol = obsCol(:);
      obsVal = obsVal(:);
      nObs = numel(obsVal);
      if nObs == 0
          error('svd_bias:noObservations', ...
                'X has no nonzero (observed) entries to learn from.');
      end
      mu = mean(obsVal);   % global mean of the observed entries

      % Random initialization (same draw order as before: bu, bi, P, Q).
      k = opts.InitScale;
      f = opts.Factors;
      bu = k*rand(m, 1);   % user biases
      bi = k*rand(n, 1);   % item biases
      P = k*rand(m, f);    % latent vectors for the users
      Q = k*rand(n, f);    % latent vectors for the items

      gamma = opts.LearningRate;
      lambda = opts.Lambda;

      if opts.Verbose
          fprintf('Performing SGD...\n');
      end
      for iter = 1:opts.MaxIters
          if opts.Verbose && rem(iter, 1000) == 0
              fprintf('iteration %d\n', iter);
          end
          s = randi(nObs);           % pick one observed entry at random
          u = obsRow(s);
          item = obsCol(s);
          Pu = P(u,:);
          Qi = Q(item,:);
          e = obsVal(s) - mu - bu(u) - bi(item) - dot(Pu, Qi);
          bu(u) = bu(u) + gamma*(e - lambda*bu(u));
          bi(item) = bi(item) + gamma*(e - lambda*bi(item));
          % Both factor updates use the pre-step Pu/Qi (simultaneous update).
          Q(item,:) = Qi + gamma*(e*Pu - lambda*Qi);
          P(u,:) = Pu + gamma*(e*Qi - lambda*Pu);
      end

      % Predict only the missing (zero) entries; observed entries are kept.
      X_pred = X;
      [missRow, missCol] = find(~X);
      missRow = missRow(:);
      missCol = missCol(:);
      if ~isempty(missRow)
          missIdx = sub2ind(size(X), missRow, missCol);
          % Row-wise dot products avoid materializing the full m-by-n P*Q'.
          X_pred(missIdx) = mu + bu(missRow) + bi(missCol) ...
                            + sum(P(missRow,:) .* Q(missCol,:), 2);
      end
  end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement