Advertisement
MagicWinnie

Solution

Mar 17th, 2023 (edited)
444
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 14.18 KB | None | 0 0
  1. #include <iostream>
  2. #include <iomanip>
  3. #include <cmath>
  4. #define USE_GNUPLOT 1
  5. #define GNUPLOT_NAME "gnuplot -persist"
  6.  
  7. using namespace std;
  8.  
  9. class Matrix
  10. {
  11. protected:
  12.     int rows = 0, cols = 0;
  13.     double **matrix;
  14.  
  15. public:
  16.     Matrix(int rows, int cols)
  17.     {
  18.         this->rows = rows;
  19.         this->cols = cols;
  20.  
  21.         matrix = new double *[rows];
  22.         for (int i = 0; i < rows; i++)
  23.             matrix[i] = new double[cols];
  24.  
  25.         for (int i = 0; i < rows; i++)
  26.             for (int j = 0; j < cols; j++)
  27.                 matrix[i][j] = 0.0;
  28.     }
  29.     ~Matrix()
  30.     {
  31.         for (int i = 0; i < rows; i++)
  32.             delete[] matrix[i];
  33.         delete[] matrix;
  34.     }
  35.     int get_rows() const
  36.     {
  37.         return rows;
  38.     }
  39.     int get_cols() const
  40.     {
  41.         return cols;
  42.     }
  43.     friend istream &operator>>(istream &in, Matrix &m)
  44.     {
  45.         for (int i = 0; i < m.rows; i++)
  46.             for (int j = 0; j < m.cols; j++)
  47.                 in >> m[i][j];
  48.         return in;
  49.     }
  50.     friend ostream &operator<<(ostream &out, Matrix &m)
  51.     {
  52.         for (int i = 0; i < m.rows; i++)
  53.         {
  54.             for (int j = 0; j < m.cols; j++)
  55.                 out << ((abs(m[i][j]) <= 1e-6 and signbit(m[i][j])) ? 0.0 : m[i][j]) << ' ';
  56.             out << '\n';
  57.         }
  58.         return out;
  59.     }
  60.     double *operator[](int i)
  61.     {
  62.         return matrix[i];
  63.     }
  64.     Matrix &operator=(Matrix &old)
  65.     {
  66.         if (this == &old)
  67.             return *this;
  68.  
  69.         for (int i = 0; i < rows; i++)
  70.             delete[] matrix[i];
  71.         delete[] matrix;
  72.  
  73.         rows = old.rows;
  74.         cols = old.cols;
  75.  
  76.         matrix = new double *[rows];
  77.         for (int i = 0; i < rows; i++)
  78.             matrix[i] = new double[cols];
  79.  
  80.         for (int i = 0; i < rows; i++)
  81.             for (int j = 0; j < cols; j++)
  82.                 matrix[i][j] = old[i][j];
  83.  
  84.         return *this;
  85.     }
  86.     Matrix *operator+(Matrix &B)
  87.     {
  88.         if (rows != B.rows or cols != B.cols)
  89.             throw runtime_error("Error: the dimensional problem occurred");
  90.         Matrix *C = new Matrix(rows, cols);
  91.         for (int i = 0; i < rows; i++)
  92.             for (int j = 0; j < cols; j++)
  93.                 C->matrix[i][j] = matrix[i][j] + B[i][j];
  94.         return C;
  95.     }
  96.     Matrix *operator-(Matrix &B)
  97.     {
  98.         if (rows != B.rows or cols != B.cols)
  99.             throw runtime_error("Error: the dimensional problem occurred");
  100.         Matrix *C = new Matrix(rows, cols);
  101.         for (int i = 0; i < rows; i++)
  102.             for (int j = 0; j < cols; j++)
  103.                 C->matrix[i][j] = matrix[i][j] - B[i][j];
  104.         return C;
  105.     }
  106.     Matrix *operator*(Matrix &B)
  107.     {
  108.         if (cols != B.rows)
  109.             throw runtime_error("Error: the dimensional problem occurred");
  110.         Matrix *C = new Matrix(rows, B.cols);
  111.         for (int i = 0; i < rows; i++)
  112.             for (int j = 0; j < B.cols; j++)
  113.                 for (int k = 0; k < cols; k++)
  114.                     C->matrix[i][j] += matrix[i][k] * B[k][j];
  115.         return C;
  116.     }
  117.     Matrix *T()
  118.     {
  119.         Matrix *B = new Matrix(cols, rows);
  120.         for (int i = 0; i < rows; i++)
  121.             for (int j = 0; j < cols; j++)
  122.                 B->matrix[j][i] = matrix[i][j];
  123.         return B;
  124.     }
  125. };
  126.  
  127. class SquareMatrix : public Matrix
  128. {
  129. protected:
  130.     friend SquareMatrix *inverse_worker(SquareMatrix *A, bool debug_info);
  131.     friend double determinant_worker(SquareMatrix *A, bool debug_info);
  132.  
  133. public:
  134.     SquareMatrix(int size) : Matrix(size, size) {}
  135.     int get_size()
  136.     {
  137.         return rows;
  138.     }
  139.     SquareMatrix *operator+(SquareMatrix &B)
  140.     {
  141.         Matrix *upcast_A = this;
  142.         Matrix *upcast_B = &B;
  143.         Matrix *upcast_C = (*upcast_A) + (*upcast_B);
  144.         SquareMatrix *C = (SquareMatrix *)(upcast_C);
  145.         return C;
  146.     }
  147.     SquareMatrix *operator-(SquareMatrix &B)
  148.     {
  149.         Matrix *upcast_A = this;
  150.         Matrix *upcast_B = &B;
  151.         Matrix *upcast_C = (*upcast_A) - (*upcast_B);
  152.         SquareMatrix *C = (SquareMatrix *)(upcast_C);
  153.         return C;
  154.     }
  155.     SquareMatrix *operator*(SquareMatrix &B)
  156.     {
  157.         Matrix *upcast_A = this;
  158.         Matrix *upcast_B = &B;
  159.         Matrix *upcast_C = (*upcast_A) * (*upcast_B);
  160.         SquareMatrix *C = (SquareMatrix *)(upcast_C);
  161.         return C;
  162.     }
  163.     SquareMatrix *T()
  164.     {
  165.         Matrix *upcast_A = this;
  166.         Matrix *upcast_C = (*upcast_A).T();
  167.         SquareMatrix *C = (SquareMatrix *)(upcast_C);
  168.         return C;
  169.     }
  170.     double determinant(bool debug_info)
  171.     {
  172.         return determinant_worker(this, debug_info);
  173.     }
  174.     SquareMatrix *inverse(bool debug_info)
  175.     {
  176.         return inverse_worker(this, debug_info);
  177.     }
  178. };
  179.  
  180. class IdentityMatrix : public SquareMatrix
  181. {
  182. public:
  183.     IdentityMatrix(int size) : SquareMatrix(size)
  184.     {
  185.         for (int i = 0; i < size; i++)
  186.             matrix[i][i] = 1.0;
  187.     }
  188. };
  189.  
  190. class EliminationMatrix : public IdentityMatrix
  191. {
  192. public:
  193.     EliminationMatrix(int size, int i, int j, double val) : IdentityMatrix(size)
  194.     {
  195.         matrix[i][j] = val * -1.0;
  196.     }
  197. };
  198.  
  199. class PermutationMatrix : public IdentityMatrix
  200. {
  201. public:
  202.     PermutationMatrix(int size, int i, int j) : IdentityMatrix(size)
  203.     {
  204.         swap(matrix[i], matrix[j]);
  205.     }
  206. };
  207.  
  208. class ColumnVector : public Matrix
  209. {
  210. public:
  211.     ColumnVector(int size) : Matrix(size, 1) {}
  212.     double *operator[](int i)
  213.     {
  214.         return &matrix[i][0];
  215.     }
  216. };
  217.  
  218. template <typename L, typename R>
  219. class AugmentedMatrix
  220. {
  221. protected:
  222.     L *matrixLeft;
  223.     R *matrixRight;
  224.  
  225. public:
  226.     AugmentedMatrix(L *A, R *B)
  227.     {
  228.         if (A->get_rows() != B->get_rows())
  229.             throw runtime_error("Error: the dimensional problem occurred");
  230.         matrixLeft = A;
  231.         matrixRight = B;
  232.     }
  233.     ~AugmentedMatrix()
  234.     {
  235.         delete[] matrixLeft;
  236.         delete[] matrixRight;
  237.     }
  238.     L *getLeft()
  239.     {
  240.         return matrixLeft;
  241.     }
  242.     R *getRight()
  243.     {
  244.         return matrixRight;
  245.     }
  246.     friend ostream &operator<<(ostream &out, AugmentedMatrix &m)
  247.     {
  248.         out << *m.getLeft() << *m.getRight();
  249.         return out;
  250.     }
  251.     AugmentedMatrix<L, R> *ForwardElimination(bool debug_info, int *last_step)
  252.     {
  253.         AugmentedMatrix<L, R> *U = this;
  254.         int curr_col = 0;
  255.         int rows = matrixLeft->get_rows();
  256.         for (int i = 0; i < rows; i++)
  257.         {
  258.             int row_with_max_pivot = -1;
  259.             double max_pivot = (*U->matrixLeft)[i][curr_col];
  260.             for (int j = i + 1; j < rows; j++)
  261.             {
  262.                 if ((*U->matrixLeft)[j][curr_col] == 0.0)
  263.                     continue;
  264.                 if (abs((*U->matrixLeft)[j][curr_col]) > abs(max_pivot))
  265.                 {
  266.                     row_with_max_pivot = j;
  267.                     max_pivot = (*U->matrixLeft)[j][curr_col];
  268.                 }
  269.             }
  270.             if (row_with_max_pivot != -1)
  271.             {
  272.                 if (debug_info)
  273.                     cout << "step #" << *last_step << ": permutation\n";
  274.                 PermutationMatrix *P = new PermutationMatrix(rows, i, row_with_max_pivot);
  275.                 U->matrixLeft = (L *)((*(Matrix *)P) * (*(Matrix *)U->matrixLeft));
  276.                 U->matrixRight = (R *)((*(Matrix *)P) * (*(Matrix *)U->matrixRight));
  277.                 if (debug_info)
  278.                 {
  279.                     cout << *U;
  280.                     (*last_step)++;
  281.                 }
  282.             }
  283.             for (int j = i + 1; j < rows; j++)
  284.             {
  285.                 if ((*U->matrixLeft)[j][curr_col] == 0.0 || (*U->matrixLeft)[i][curr_col] == 0.0)
  286.                     continue;
  287.                 if (debug_info)
  288.                     cout << "step #" << *last_step << ": elimination\n";
  289.                 EliminationMatrix *E = new EliminationMatrix(rows, j, i, (*U->matrixLeft)[j][curr_col] / (*U->matrixLeft)[i][curr_col]);
  290.                 U->matrixLeft = (L *)((*(Matrix *)E) * (*(Matrix *)U->matrixLeft));
  291.                 U->matrixRight = (R *)((*(Matrix *)E) * (*(Matrix *)U->matrixRight));
  292.                 if (debug_info)
  293.                 {
  294.                     cout << *U;
  295.                     (*last_step)++;
  296.                 }
  297.             }
  298.             curr_col++;
  299.         }
  300.         return U;
  301.     }
  302.     AugmentedMatrix<L, R> *BackwardElimination(bool debug_info, int *last_step)
  303.     {
  304.         AugmentedMatrix<L, R> *U = this;
  305.         int curr_col = matrixLeft->get_cols() - 1;
  306.         int rows = matrixLeft->get_rows();
  307.         for (int i = rows - 1; i >= 0; i--)
  308.         {
  309.             for (int j = i - 1; j >= 0; j--)
  310.             {
  311.                 if ((*U->matrixLeft)[j][curr_col] == 0.0 || (*U->matrixLeft)[i][curr_col] == 0.0)
  312.                     continue;
  313.                 if (debug_info)
  314.                     cout << "step #" << *last_step << ": elimination\n";
  315.                 EliminationMatrix *E = new EliminationMatrix(rows, j, i, (*U->matrixLeft)[j][curr_col] / (*U->matrixLeft)[i][curr_col]);
  316.                 U->matrixLeft = (L *)((*(Matrix *)E) * (*(Matrix *)U->matrixLeft));
  317.                 U->matrixRight = (R *)((*(Matrix *)E) * (*(Matrix *)U->matrixRight));
  318.                 if (debug_info)
  319.                 {
  320.                     cout << *U;
  321.                     (*last_step)++;
  322.                 }
  323.             }
  324.             curr_col--;
  325.         }
  326.         return U;
  327.     }
  328.     friend ColumnVector *solve(AugmentedMatrix<SquareMatrix, ColumnVector> *A, bool debug_info);
  329. };
  330.  
  331. double determinant_worker(SquareMatrix *A, bool debug_info)
  332. {
  333.     int step = 0;
  334.     AugmentedMatrix<SquareMatrix, SquareMatrix> *aug = new AugmentedMatrix(A, A);
  335.     SquareMatrix *U = aug->ForwardElimination(debug_info, &step)->getLeft();
  336.     double res = 1.0;
  337.     for (int i = 0; i < A->get_size(); i++)
  338.         res *= U->matrix[i][i];
  339.     return res;
  340. }
  341.  
  342. SquareMatrix *inverse_worker(SquareMatrix *A, bool debug_info)
  343. {
  344.     IdentityMatrix *I = new IdentityMatrix(A->get_size());
  345.     SquareMatrix *I_matrix = I;
  346.     if (debug_info)
  347.         cout << "step #0: Augmented Matrix\n";
  348.     AugmentedMatrix<SquareMatrix, SquareMatrix> *aug = new AugmentedMatrix(A, I_matrix);
  349.     if (debug_info)
  350.         cout << *aug;
  351.  
  352.     int last_step = 1;
  353.     if (debug_info)
  354.         cout << "Direct way:\n";
  355.     AugmentedMatrix<SquareMatrix, SquareMatrix> *B = aug->ForwardElimination(debug_info, &last_step);
  356.  
  357.     if (debug_info)
  358.         cout << "Way back:\n";
  359.     AugmentedMatrix<SquareMatrix, SquareMatrix> *C = B->BackwardElimination(debug_info, &last_step);
  360.  
  361.     if (debug_info)
  362.         cout << "Diagonal normalization:\n";
  363.     for (int i = 0; i < C->getLeft()->get_rows(); i++)
  364.     {
  365.         double pivot = (*C->getLeft())[i][i];
  366.         for (int j = 0; j < C->getLeft()->get_cols(); j++)
  367.             (*C->getLeft())[i][j] /= pivot;
  368.         for (int j = 0; j < C->getRight()->get_cols(); j++)
  369.             (*C->getRight())[i][j] /= pivot;
  370.     }
  371.     if (debug_info)
  372.         cout << *C;
  373.  
  374.     return (SquareMatrix *)C->getRight();
  375. }
  376.  
  377. ColumnVector *solve(AugmentedMatrix<SquareMatrix, ColumnVector> *A, bool debug_info)
  378. {
  379.     if (debug_info)
  380.     {
  381.         cout << "step #0:\n";
  382.         cout << *A;
  383.     }
  384.  
  385.     int last_step = 1;
  386.     AugmentedMatrix<SquareMatrix, ColumnVector> *B = A->ForwardElimination(debug_info, &last_step);
  387.     AugmentedMatrix<SquareMatrix, ColumnVector> *C = B->BackwardElimination(debug_info, &last_step);
  388.  
  389.     if (debug_info)
  390.         cout << "Diagonal normalization:\n";
  391.     for (int i = 0; i < C->getLeft()->get_rows(); i++)
  392.     {
  393.         double pivot = (*C->getLeft())[i][i];
  394.         if (pivot == 0.0)
  395.             continue;
  396.         for (int j = 0; j < C->getLeft()->get_cols(); j++)
  397.             (*C->getLeft())[i][j] /= pivot;
  398.         for (int j = 0; j < C->getRight()->get_cols(); j++)
  399.             (*C->getRight())[i][j] /= pivot;
  400.     }
  401.     if (debug_info)
  402.         cout << *C;
  403.  
  404.     return A->matrixRight;
  405. }
  406.  
  407. int main(void)
  408. {
  409.     cout << fixed << setprecision(4);
  410.  
  411.     int m;
  412.     cin >> m;
  413.     double t[m];
  414.     ColumnVector *b = new ColumnVector(m);
  415.     for (int i = 0; i < m; i++)
  416.     {
  417.         double t_i, b_i;
  418.         cin >> t_i >> b_i;
  419.         t[i] = t_i;
  420.         (*b->operator[](i)) = b_i;
  421.     }
  422.     int n;
  423.     cin >> n;
  424.  
  425.     Matrix *A = new Matrix(m, n + 1);
  426.     for (int i = 0; i <= n; i++)
  427.         for (int j = 0; j < m; j++)
  428.             A->operator[](j)[i] = pow(t[j], (double)i);
  429.  
  430.     cout << "A:\n";
  431.     cout << *A;
  432.  
  433.     Matrix *A_T = A->T();
  434.     SquareMatrix *A_1 = (SquareMatrix *)((*A_T) * (*A));
  435.  
  436.     cout << "A_T*A:\n";
  437.     cout << *A_1;
  438.  
  439.     SquareMatrix *A_2 = A_1->inverse(false);
  440.  
  441.     cout << "(A_T*A)^-1:\n";
  442.     cout << *A_2;
  443.  
  444.     ColumnVector *A_3 = (ColumnVector *)((*(Matrix *)A_T) * (*(Matrix *)b));
  445.  
  446.     cout << "A_T*b:\n";
  447.     cout << *A_3;
  448.  
  449.     ColumnVector *A_4 = (ColumnVector *)((*(Matrix *)A_2) * (*(Matrix *)A_3));
  450.  
  451.     cout << "x~:\n";
  452.     cout << *A_4;
  453.  
  454. #if (defined(WIN32) || defined(_WIN32)) && USE_GNUPLOT
  455.     FILE *pipe = _popen(GNUPLOT_NAME, "w");
  456. #elif USE_GNUPLOT
  457.     FILE *pipe = popen(GNUPLOT_NAME, "w");
  458. #endif
  459. #if USE_GNUPLOT
  460.     fprintf(pipe, "%s\n", "set terminal png");
  461.     fprintf(pipe, "%s\n", "set output 'output.png'");
  462.     fprintf(pipe, "%s\n", "set title \"Least Squares Approximation\"");
  463.     fprintf(pipe, "%s\n", "set key noautotitle");
  464.     fprintf(pipe, "%s\n", "set autoscale xy");
  465.     fprintf(pipe, "%s\n", "set offsets 0.05, 0.05, 0.05, 0.05");
  466.     string func;
  467.     for (int i = 0; i <= n; i++)
  468.     {
  469.         if (*A_4->operator[](i) < 0 and i != 0)
  470.             func = func.substr(0, func.size() - 1);
  471.         func += to_string(*A_4->operator[](i));
  472.         func += '*';
  473.         func += "x**";
  474.         func += to_string(i);
  475.         if (i != n)
  476.             func += '+';
  477.     }
  478.     cout << func << endl;
  479.     fprintf(pipe, "plot %s lw 3, '-' w p pt 7 ps 2\n", func.c_str());
  480.     for (int i = 0; i < m; i++)
  481.         fprintf(pipe, "%lf %lf\n", t[i], *b->operator[](i));
  482.     fprintf(pipe, "%s\n", "e");
  483.     fflush(pipe);
  484. #endif
  485. #if (defined(WIN32) || defined(_WIN32)) && USE_GNUPLOT
  486.     _pclose(pipe);
  487. #elif USE_GNUPLOT
  488.     pclose(pipe);
  489. #endif
  490.  
  491.     return 0;
  492. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement