Advertisement
Guest User

Untitled

a guest
Jun 25th, 2013
34
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.10 KB | None | 0 0
  #include "jadiag.h"

  #include <cmath>

  #include <shogun/base/init.h>
  #include <shogun/lib/common.h>
  #include <shogun/io/SGIO.h>

  #include <shogun/mathematics/Math.h>
  #include <shogun/mathematics/eigen3.h>

  10. using namespace Eigen;
  11.  
  12. typedef Matrix< float64_t, Dynamic, 1, ColMajor > EVector;
  13. typedef Matrix< float64_t, Dynamic, Dynamic, ColMajor > EMatrix;
  14.  
  15. using namespace shogun;
  16.  
  17. SGMatrix<float64_t> jadiag(SGNDArray<float64_t> &M, SGMatrix<float64_t> *W_est0,
  18. double eps, int itermax)
  19. {
  20. int L = M.dims[2];
  21. int d = M.dims[0];
  22. double decr = 1;
  23. double logdet = log(5.184e17);
  24.  
  25. SGMatrix<float64_t> W_est;
  26. if(W_est0 != NULL)
  27. {
  28. W_est = W_est0->clone();
  29. }
  30. else
  31. {
  32. W_est = SGMatrix<float64_t>::create_identity_matrix(d,1);
  33. }
  34.  
  35. double result = 0;
  36.  
  37. EVector w(L);
  38. w.setOnes();
  39.  
  40. EMatrix ctot(d, d*L);
  41. for(int i = 0; i < L; i++)
  42. {
  43. Eigen::Map<EMatrix> Ci(M.get_matrix(i),d,d);
  44. ctot.block(0,i*d,d,d) = Ci;
  45. }
  46.  
  47. double crit;
  48. int iter = 0;
  49.  
  50. while(decr > eps && iter < itermax)
  51. {
  52. if(logdet == 0)// is NA
  53. {
  54. SG_SERROR("log det does not exist\n")
  55. break;
  56. }
  57.  
  58. jadiagw(ctot.data(),
  59. w.data(),
  60. &d, &L,
  61. W_est.matrix,
  62. &logdet,
  63. &decr,
  64. &result);
  65.  
  66. iter = iter + 1;
  67. }
  68.  
  69. if(iter == itermax)
  70. {
  71. SG_SERROR("Convergence not reached\n")
  72. }
  73.  
  74. return W_est;
  75. }
  76.  
  /** One sweep of the JADIAG pairwise update.
   *
   * For every index pair (i, j) with j < i, a 2x2 transformation
   * [[1, a12], [a21, 1]] is derived from weighted ratios of entries of the m
   * matrices in c. It is applied to rows/columns i and j of every matrix in c
   * and to rows i and j of a.
   *
   * Layout: entry (r, s) of the k0-th n x n matrix is c[r + s*n + k0*n*n]
   * (column-major). Only the diagonal and the lower triangle (r > s) of each
   * matrix are read and updated, so the matrices are presumably symmetric.
   *
   * @param c      in/out: m stacked n x n matrices. Diagonal entries are used
   *               as divisors and their product goes through log(), so they
   *               are presumably positive -- TODO confirm with callers.
   * @param w      m weights, one per matrix
   * @param ptn    pointer to n, the matrix dimension
   * @param ptm    pointer to m, the number of matrices
   * @param a      in/out: n x n matrix receiving the same row operations
   * @param logdet in/out: incremented by 2*sum(w)*log(det), where det is the
   *               product of (1 - a12*a21) over all pairs of this sweep
   * @param decr   out: decrease estimate accumulated over all pairs
   * @param result out: sum_k w[k]*log(prod(diag(C_k))) - *logdet, evaluated
   *               after the sweep
   */
  void jadiagw(double c[], double w[], int *ptn, int *ptm, double a[],
      double *logdet, double *decr, double *result)
  {
      const int n = *ptn;
      const int m = *ptm;
      const int n2 = n*n;     // number of entries in one matrix
      const int mn2 = m*n2;   // number of entries in all of c
      int i, ic, ii, ij, j, jc, jj, k, k0;
      double sumweigh, p2, q1, p, q, alpha, beta, gamma, a12, a21, det;
      // The former 'register' hints are gone: they have no effect on modern
      // compilers and the specifier is ill-formed since C++17.
      double tmp1, tmp2, tmp, weigh;

      for (sumweigh = 0, i = 0; i < m; i++)
          sumweigh += w[i];

      det = 1;
      *decr = 0;

      // ic = i*n and jc = j*n are the offsets of columns i and j.
      for (i = 1, ic = n; i < n; i++, ic += n)
      {
          for (j = jc = 0; j < i; j++, jc += n)
          {
              ii = i + ic;    // (i, i)
              jj = j + jc;    // (j, j)
              ij = i + jc;    // (i, j), lower triangle

              // Weighted sums of entry ratios over all matrices.
              for (q1 = p2 = p = q = 0, k0 = k = 0; k0 < m; k0++, k += n2)
              {
                  weigh = w[k0];
                  tmp1 = c[ii+k];
                  tmp2 = c[jj+k];
                  tmp = c[ij+k];
                  p += weigh*tmp/tmp1;
                  q += weigh*tmp/tmp2;
                  q1 += weigh*tmp1/tmp2;
                  p2 += weigh*tmp2/tmp1;
              }

              q1 /= sumweigh;
              p2 /= sumweigh;
              p /= sumweigh;
              q /= sumweigh;
              beta = 1 - p2*q1;   // p1 = q2 = 1

              // The threshold literal 10e-20 equals 1e-19; it is kept exactly
              // as written to preserve numerical behaviour.
              if (q1 <= p2)   // the same as q1*q2 <= p1*p2
              {
                  alpha = p2*q - p;   // q2 = 1

                  if (fabs(alpha) - beta < 10e-20)    // beta <= 0 always
                  {
                      beta = -1;
                      gamma = p/p2;
                  }
                  else
                  {
                      gamma = - (p*beta + alpha)/p2;  // p1 = 1
                  }

                  *decr += sumweigh*(p*p - alpha*alpha/beta)/p2;
              }
              else
              {
                  gamma = p*q1 - q;   // p1 = 1

                  if (fabs(gamma) - beta < 10e-20)    // beta <= 0 always
                  {
                      beta = -1;
                      alpha = q/q1;
                  }
                  else
                  {
                      alpha = - (q*beta + gamma)/q1;  // q2 = 1
                  }

                  *decr += sumweigh*(q*q - gamma*gamma/beta)/q1;
              }

              tmp = (beta - sqrt(beta*beta - 4*alpha*gamma))/2;
              a12 = gamma/tmp;
              a21 = alpha/tmp;

              // Apply the transformation to every matrix. The statement order
              // below is significant: later updates read entries written by
              // earlier ones.
              for (k = 0; k < mn2; k += n2)
              {
                  // entries (i, s) and (j, s) for s < j
                  for (ii = i, jj = j; ii < ij; ii += n, jj += n)
                  {
                      tmp = c[ii+k];
                      c[ii+k] += a12*c[jj+k];
                      c[jj+k] += a21*tmp;
                  }   // at exit ii = ij = i + jc

                  tmp = c[i+ic+k];
                  c[i+ic+k] += a12*(2*c[ij+k] + a12*c[jj+k]);
                  c[jj+k] += a21*c[ij+k];
                  c[ij+k] += a21*tmp;   // = element of index j,i

                  // entries (i, s) and (s, j) for j <= s < i
                  for (; ii < ic; ii += n, jj++)
                  {
                      tmp = c[ii+k];
                      c[ii+k] += a12*c[jj+k];
                      c[jj+k] += a21*tmp;
                  }

                  // entries (r, i) and (r, j) for i < r < n
                  for (; ++ii, ++jj < jc+n; )
                  {
                      tmp = c[ii+k];
                      c[ii+k] += a12*c[jj+k];
                      c[jj+k] += a21*tmp;
                  }
              }

              // Same row operation on a: rows i and j, across all n columns.
              for (k = 0; k < n2; k += n)
              {
                  tmp = a[i+k];
                  a[i+k] += a12*a[j+k];
                  a[j+k] += a21*tmp;
              }

              det *= 1 - a12*a21;   // determinant of this pair's 2x2 transform
          }
      }

      *logdet += 2*sumweigh*log(det);

      // Criterion: weighted sum over matrices of log(product of the diagonal).
      for (tmp = 0, k0 = k = 0; k0 < m; k0++, k += n2)
      {
          for (det = 1, ii = 0; ii < n2; ii += n+1)
              det *= c[ii+k];
          // Fix: exactly one log() per matrix, of the complete diagonal
          // product. This line used to sit inside the loop above and summed
          // the logs of every partial product instead.
          tmp += w[k0]*log(det);
      }

      *result = tmp - *logdet;
  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement