RichieHard

Untitled

Apr 21st, 2020
53
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.24 KB | None | 0 0
  1. /*
  2. * MatrixMath.cpp Library for Matrix Math
  3. *
  4. * Created by Charlie Matlack on 12/18/10.
  5. * Modified from code by RobH45345 on Arduino Forums, algorithm from
  6. * NUMERICAL RECIPES: The Art of Scientific Computing.
  7. */
  8.  
  9. #include "MatrixMath.h"
  10.  
  // NR_END: left over from the Numerical Recipes code this file was adapted from;
  // no routine in this file refers to it.
  11. #define NR_END 1
  12.  
  // Global MatrixMath object so callers can invoke the routines below as
  // Matrix.<Routine>(...) without constructing their own instance.
  // NOTE(review): presumably declared `extern` in MatrixMath.h — confirm.
  13. MatrixMath Matrix; // Pre-instantiate
  14.  
  15. // Matrix Printing Routine
  16. // Uses tabs to separate numbers under assumption printed mtx_type width won't cause problems
  17. void MatrixMath::Print(mtx_type* A, int m, int n, String label)
  18. {
  19. // A = input matrix (m x n)
  20. int i, j;
  21. Serial.println();
  22. Serial.println(label);
  23. for (i = 0; i < m; i++)
  24. {
  25. for (j = 0; j < n; j++)
  26. {
  27. Serial.print(A[n * i + j]);
  28. Serial.print("\t");
  29. }
  30. Serial.println();
  31. }
  32. }
  33.  
  34. void MatrixMath::Copy(mtx_type* A, int n, int m, mtx_type* B)
  35. {
  36. int i, j;
  37. for (i = 0; i < m; i++)
  38. for(j = 0; j < n; j++)
  39. {
  40. B[n * i + j] = A[n * i + j];
  41. }
  42. }
  43.  
  44. //Matrix Multiplication Routine
  45. // C = A*B
  46. void MatrixMath::Multiply(mtx_type* A, mtx_type* B, int m, int p, int n, mtx_type* C)
  47. {
  48. // A = input matrix (m x p)
  49. // B = input matrix (p x n)
  50. // m = number of rows in A
  51. // p = number of columns in A = number of rows in B
  52. // n = number of columns in B
  53. // C = output matrix = A*B (m x n)
  54. int i, j, k;
  55. for (i = 0; i < m; i++)
  56. for(j = 0; j < n; j++)
  57. {
  58. C[n * i + j] = 0;
  59. for (k = 0; k < p; k++)
  60. C[n * i + j] = C[n * i + j] + A[p * i + k] * B[n * k + j];
  61. }
  62. }
  63.  
  64.  
  65. //Matrix Addition Routine
  66. void MatrixMath::Add(mtx_type* A, mtx_type* B, int m, int n, mtx_type* C)
  67. {
  68. // A = input matrix (m x n)
  69. // B = input matrix (m x n)
  70. // m = number of rows in A = number of rows in B
  71. // n = number of columns in A = number of columns in B
  72. // C = output matrix = A+B (m x n)
  73. int i, j;
  74. for (i = 0; i < m; i++)
  75. for(j = 0; j < n; j++)
  76. C[n * i + j] = A[n * i + j] + B[n * i + j];
  77. }
  78.  
  79.  
  80. //Matrix Subtraction Routine
  81. void MatrixMath::Subtract(mtx_type* A, mtx_type* B, int m, int n, mtx_type* C)
  82. {
  83. // A = input matrix (m x n)
  84. // B = input matrix (m x n)
  85. // m = number of rows in A = number of rows in B
  86. // n = number of columns in A = number of columns in B
  87. // C = output matrix = A-B (m x n)
  88. int i, j;
  89. for (i = 0; i < m; i++)
  90. for(j = 0; j < n; j++)
  91. C[n * i + j] = A[n * i + j] - B[n * i + j];
  92. }
  93.  
  94.  
  95. //Matrix Transpose Routine
  96. void MatrixMath::Transpose(mtx_type* A, int m, int n, mtx_type* C)
  97. {
  98. // A = input matrix (m x n)
  99. // m = number of rows in A
  100. // n = number of columns in A
  101. // C = output matrix = the transpose of A (n x m)
  102. int i, j;
  103. for (i = 0; i < m; i++)
  104. for(j = 0; j < n; j++)
  105. C[m * j + i] = A[n * i + j];
  106. }
  107.  
  108. void MatrixMath::Scale(mtx_type* A, int m, int n, mtx_type k)
  109. {
  110. for (int i = 0; i < m; i++)
  111. for (int j = 0; j < n; j++)
  112. A[n * i + j] = A[n * i + j] * k;
  113. }
  114.  
  115.  
  116. //Matrix Inversion Routine
  117. // * This function inverts a matrix based on the Gauss Jordan method.
  118. // * Specifically, it uses partial pivoting to improve numeric stability.
  119. // * The algorithm is drawn from those presented in
  120. // NUMERICAL RECIPES: The Art of Scientific Computing.
  121. // * The function returns 1 on success, 0 on failure.
  122. // * NOTE: The argument is ALSO the result matrix, meaning the input matrix is REPLACED
  123. int MatrixMath::Invert(mtx_type* A, int n)
  124. {
  125. // A = input matrix AND result matrix
  126. // n = number of rows = number of columns in A (n x n)
  127. int pivrow = 0; // keeps track of current pivot row
  128. int k, i, j; // k: overall index along diagonal; i: row index; j: col index
  129. int pivrows[n]; // keeps track of rows swaps to undo at end
  130. mtx_type tmp; // used for finding max value and making column swaps
  131.  
  132. for (k = 0; k < n; k++)
  133. {
  134. // find pivot row, the row with biggest entry in current column
  135. tmp = 0;
  136. for (i = k; i < n; i++)
  137. {
  138. if (abs(A[i * n + k]) >= tmp) // 'Avoid using other functions inside abs()?'
  139. {
  140. tmp = abs(A[i * n + k]);
  141. pivrow = i;
  142. }
  143. }
  144.  
  145. // check for singular matrix
  146. if (A[pivrow * n + k] == 0.0f)
  147. {
  148. Serial.println("Inversion failed due to singular matrix");
  149. return 0;
  150. }
  151.  
  152. // Execute pivot (row swap) if needed
  153. if (pivrow != k)
  154. {
  155. // swap row k with pivrow
  156. for (j = 0; j < n; j++)
  157. {
  158. tmp = A[k * n + j];
  159. A[k * n + j] = A[pivrow * n + j];
  160. A[pivrow * n + j] = tmp;
  161. }
  162. }
  163. pivrows[k] = pivrow; // record row swap (even if no swap happened)
  164.  
  165. tmp = 1.0f / A[k * n + k]; // invert pivot element
  166. A[k * n + k] = 1.0f; // This element of input matrix becomes result matrix
  167.  
  168. // Perform row reduction (divide every element by pivot)
  169. for (j = 0; j < n; j++)
  170. {
  171. A[k * n + j] = A[k * n + j] * tmp;
  172. }
  173.  
  174. // Now eliminate all other entries in this column
  175. for (i = 0; i < n; i++)
  176. {
  177. if (i != k)
  178. {
  179. tmp = A[i * n + k];
  180. A[i * n + k] = 0.0f; // The other place where in matrix becomes result mat
  181. for (j = 0; j < n; j++)
  182. {
  183. A[i * n + j] = A[i * n + j] - A[k * n + j] * tmp;
  184. }
  185. }
  186. }
  187. }
  188.  
  189. // Done, now need to undo pivot row swaps by doing column swaps in reverse order
  190. for (k = n - 1; k >= 0; k--)
  191. {
  192. if (pivrows[k] != k)
  193. {
  194. for (i = 0; i < n; i++)
  195. {
  196. tmp = A[i * n + k];
  197. A[i * n + k] = A[i * n + pivrows[k]];
  198. A[i * n + pivrows[k]] = tmp;
  199. }
  200. }
  201. }
  202. return 1;
  203. }
Advertisement
Add Comment
Please, Sign In to add comment