Ackerven

STL

Dec 1st, 2021
645
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. // ConsoleApplication1.cpp : 定义控制台应用程序的入口点。
  2. //
  3.  
  4. #include<iostream>
  5. #include<algorithm>
  6. #include<vector>
  7. #include<numeric>
  8.  
  9. using namespace std;
  10.  
  11. //函数对象
  12. class SquareError {
  13. public:
  14.     float a, b;
  15.     SquareError(float aa, float bb) :a(aa), b(bb) {}
  16.     float operator()(float x, float y) {
  17.         return (a * x + b - y) * (a * x + b - y);
  18.     }
  19. };
  20.  
  21. //x, y表示观测数据,a,b表示直线参数,error表示拟合误差, error = sum((ax+b - y)^2)
  22. void least_squre(vector<float>& x, vector<float>& y, float& a, float& b, float& error) {
  23.     vector<float> re(x.size());
  24.     //使用transform有两种方法
  25.     //transform(容器1开始的位置,容器1结束的位置,容器2开始的位置,容器3开始的位置,自定义对象 )
  26.     // 容器二和容器三不需要给结束位置,会根据容器1的开始和结束计算
  27.     // 自定义对象可以是匿名函数,也可以是函数对象
  28.    
  29.     // 第一种:匿名函数
  30.     //transform(x.begin(), x.end(), y.begin(), re.begin(), [=](float x, float y) {return (a * x + b - y) * (a * x + b - y); });
  31.     auto f = [=](float x, float y) {return (a * x + b - y) * (a * x + b - y); };
  32.     transform(x.begin(), x.end(), y.begin(), re.begin(), f);
  33.  
  34.     //第二种:函数对象
  35.     SquareError sse(a, b);
  36.     transform(x.begin(), x.end(), y.begin(), re.begin(), sse);
  37.     error = accumulate(re.begin(), re.end(), 0.0);
  38. }
  39.  
  40. //计算u
  41. float calcu(vector<float>& x) {
  42.     return accumulate(x.begin(), x.end(), 0.0) / x.size();
  43. }
  44.  
  45. //计算a
  46. float calca(vector<float>& x, vector<float>& y, float ux, float uy) {
  47.     float a1 = (inner_product(x.begin(), x.end(), y.begin(), 1.0) / x.size()) - ux * uy;
  48.     float a2 = (inner_product(x.begin(), x.end(), x.begin(), 1.0) / x.size()) - ux * ux;
  49.     float a = a1 / a2;
  50.     return a;
  51. }
  52.  
  53. //计算b
  54. float calcb(float a, float ux, float uy) {
  55.     return uy - a * ux;
  56. }
  57.  
  58. int main()
  59. {
  60.     float x1[] = { 6.19f, 2.51f, 7.29f, 7.01f, 5.7f, 2.66f, 3.98f,2.5f, 9.1f, 4.2f };
  61.     float y1[] = { 5.25f, 2.83f, 6.41f, 6.71f, 5.1f, 4.23f, 5.05f, 1.98f, 10.5f, 6.3f };
  62.     float x2[] = { 208.0f, 152.0f, 113.0f, 227.0f, 137.0f, 238.0f, 178.0f, 104.0f, 191.0f, 130.0f };
  63.     float y2[] = { 21.6f, 15.5f, 10.4f, 31.0f, 13.0f, 32.4f, 19.0f, 10.4f, 19.0f,11.8f };
  64.  
  65.     //将数组转换为vector  (其实不转也没关系)
  66.     vector<float> vx1(x1, x1 + sizeof(x1) / sizeof(x1[0]));
  67.     vector<float> vy1(y1, y1 + sizeof(y1) / sizeof(y1[0]));
  68.     vector<float> vx2(x2, x2 + sizeof(x2) / sizeof(x2[0]));
  69.     vector<float> vy2(y2, y2 + sizeof(y2) / sizeof(y2[0]));
  70.  
  71.     //调用拟合函数  
  72.     float error = 0.0;
  73.     float ux = calcu(vx1);
  74.     float uy = calcu(vy1);
  75.     float a = calca(vx1, vy1, ux, uy);
  76.     float b = calcb(a, ux, uy);
  77.     least_squre(vx1, vy1, a, b, error);
  78.  
  79.     //输出拟合结果
  80.     cout << error << endl;
  81.  
  82.     error = 0.0;
  83.     ux = calcu(vx2);
  84.     uy = calcu(vy2);
  85.     a = calca(vx2, vy2, ux, uy);
  86.     b = calcb(a, ux, uy);
  87.     least_squre(vx2, vy2, a, b, error);
  88.  
  89.     cout << error << endl;
  90.  
  91.     return 0;
  92.  
  93. }
  94.  
  95.  
RAW Paste Data