Advertisement
Guest User

Untitled

a guest
Oct 15th, 2019
113
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.10 KB | None | 0 0
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>
  6.  
  7. // ガウス過程回帰
  8. // 注意: あんまりテストしてない
  9. class GP {
  10. static constexpr double EPS = 1e-10;
  11. using vec = std::vector<double>;
  12. using mat = std::vector<vec>;
  13.  
  14. mat data_x;
  15. vec data_y;
  16. double data_mu;
  17. double data_sgm2;
  18. mat invK;
  19. std::function<double(const vec&, const vec&)> kernel;
  20.  
  21. // x^T A yを計算
  22. double product(const vec& x, const mat& A, const vec& y) {
  23. double ret = 0;
  24. int N = A.size();
  25. for (int i = 0; i < N; i++) {
  26. double sum = 0;
  27. for (int j = 0; j < N; j++) sum += A[i][j] * y[j];
  28. ret += x[i] * sum;
  29. }
  30. return ret;
  31. }
  32.  
  33. // Aの逆行列を計算
  34. mat inverse(mat A) {
  35. int N = A.size();
  36. // LU分解
  37. std::vector<int> p(N), ip(N);
  38. for (int i = 0; i < N; i++) p[i] = i;
  39. for (int i = 0; i < N; i++) {
  40. int pivot = i;
  41. for (int j = i + 1; j < N; j++) {
  42. if (fabs(A[j][i]) > fabs(A[pivot][i])) {
  43. pivot = j;
  44. }
  45. }
  46. std::swap(A[pivot], A[i]);
  47. std::swap(p[pivot], p[i]);
  48. if (fabs(A[i][i]) < EPS) // detA=0
  49. return mat();
  50.  
  51. for (int j = i + 1; j < N; j++) {
  52. A[j][i] /= A[i][i];
  53. for (int k = i + 1; k < N; k++) {
  54. A[j][k] -= A[i][k] * A[j][i];
  55. }
  56. }
  57. }
  58. for (int i = 0; i < N; i++) ip[p[i]] = i;
  59. // 逆行列
  60. mat B(N, vec(N, 0)), C(N, vec(N, 0)), ret(N, vec(N));
  61. for (int i = 0; i < N; i++) B[i][i] = 1.0 / A[i][i];
  62. for (int i = 1; i < N; i++) {
  63. for (int j = 0, k = i; k < N; j++, k++) {
  64. for (int l = j + 1; l <= k; l++) B[j][k] -= A[j][l] * B[l][k];
  65. B[j][k] /= A[j][j];
  66. for (int l = j; l < k; l++)
  67. B[k][j] -=
  68. (k <= l ? 1.0 : A[k][l]) * (l <= j ? 1.0 : B[l][j]);
  69. }
  70. }
  71. for (int i = 0; i < N; i++) {
  72. for (int j = 0; j < N; j++) {
  73. for (int k = 0; k < N; k++) {
  74. double u = (i > k ? 0 : B[i][k]);
  75. double l = (j > k ? 0 : (j == k ? 1.0 : B[k][j]));
  76. C[i][j] += u * l;
  77. }
  78. }
  79. }
  80. for (int i = 0; i < N; i++) {
  81. for (int j = 0; j < N; j++) {
  82. ret[j][i] = C[j][ip[i]];
  83. }
  84. }
  85. return ret;
  86. }
  87.  
  88. public:
  89. // カーネルの設定
  90. void setKernel(const std::function<double(const vec&, const vec&)>& f) {
  91. kernel = f;
  92. }
  93. // データ(y, x)と、yの正規ノイズの分散sgm2の設定
  94. void setData(const vec& y, const mat& x, double sgm2) {
  95. int N = y.size();
  96. // 平均
  97. double mu = 0;
  98. for (int i = 0; i < N; i++) mu += y[i];
  99. mu /= N;
  100. // 平均を引いたもの
  101. vec yy(N);
  102. for (int i = 0; i < N; i++) yy[i] = y[i] - mu;
  103. // 行列Kと逆行列
  104. mat K(N, vec(N, 0));
  105. for (int i = 0; i < N; i++) {
  106. for (int j = 0; j < N; j++) {
  107. K[i][j] = kernel(x[i], x[j]) + (i == j ? sgm2 : 0);
  108. }
  109. }
  110.  
  111. data_x = x;
  112. data_y = yy;
  113. data_mu = mu;
  114. data_sgm2 = sgm2;
  115. invK = inverse(K); // O(N^3)
  116. }
  117. // xのときの予測値 Normal(mu, sgm2)
  118. std::pair<double, double> predict(const vec& x) {
  119. int N = data_y.size();
  120. vec kk(N);
  121. for (int i = 0; i < N; i++) kk[i] = kernel(data_x[i], x);
  122. double kkk = kernel(x, x) + data_sgm2;
  123. return std::make_pair(data_mu + product(kk, invK, data_y),
  124. std::max(0.0, kkk - product(kk, invK, kk)));
  125. }
  126. };
  127.  
  128. int main() {
  129. GP gp;
  130.  
  131. // ガウスカーネルをセット
  132. gp.setKernel(
  133. [](const std::vector<double>& x, const std::vector<double>& y) {
  134. // パラメータは適当
  135. double theta1 = 1.0;
  136. double theta2 = 1.0;
  137. double sum = 0;
  138. for (int i = 0; i < x.size(); i++)
  139. sum += (x[i] - y[i]) * (x[i] - y[i]);
  140. return theta1 * exp(-sum / theta2);
  141. });
  142.  
  143. // データ
  144. std::vector<double> y;
  145. std::vector<std::vector<double>> x;
  146. x.push_back({1.6});
  147. y.push_back(1.9);
  148.  
  149. x.push_back({4.0});
  150. y.push_back(1.1);
  151.  
  152. x.push_back({2.1});
  153. y.push_back(1.85);
  154.  
  155. x.push_back({0.5});
  156. y.push_back(0.8);
  157.  
  158. x.push_back({2.2});
  159. y.push_back(2.2);
  160.  
  161. // データのセット
  162. gp.setData(y, x, 0.05);
  163. // 予測
  164. for (double i = 0; i <= 5; i += 0.05) {
  165. std::vector<double> tmp{i};
  166. std::pair<double, double> res = gp.predict(tmp);
  167. double mu = res.first;
  168. double sigma = sqrt(res.second);
  169. // 期待値
  170. std::cout << i << "\t" << mu << "\t";
  171. // 2σ区間(におさまる確率=約95.45%)
  172. std::cout << mu - 2 * sigma << "\t" << mu + 2 * sigma << std::endl;
  173. }
  174.  
  175. return 0;
  176. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement