Advertisement
Guest User

Untitled

a guest
Nov 22nd, 2019
95
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 20.01 KB | None | 0 0
  1. #include <time.h>
  2. #include <stdio.h>
  3. #include <stdlib.h>
  4.  
  5. #include <mmintrin.h>
  6. #include <xmmintrin.h> // SSE
  7. #include <pmmintrin.h> // SSE2
  8. #include <emmintrin.h> // SSE3
  9.  
  /* Element type shared by the scalar kernels (mm0..mm7), the matrix
   * helpers and the benchmark harness.  mm8_d / mm8_f declare their own
   * element types (double / float) in their signatures. */
  #define T float

  /*
   * Allocate an n x n row-major matrix and fill it with pseudo-random values
   * in [0, 1] taken from rand().
   *
   * Returns a buffer owned by the caller; release it with destroy_matrix().
   * For n <= 0 no elements are written and the result may be NULL.
   * Aborts the program if the allocation fails.
   */
  T* make_matrix(int n)
  {
      /* Size arithmetic in size_t: n*n in int overflows for large n. */
      const size_t count = n > 0 ? (size_t)n * (size_t)n : 0;
      T* mat = malloc(sizeof *mat * count);
      if (mat == NULL && count != 0)
      {
          fprintf(stderr, "make_matrix: out of memory (n = %d)\n", n);
          abort();
      }
      for (size_t i = 0; i < count; ++i)
      {
          /* Convert before dividing: integer rand() / RAND_MAX is 0 for
           * every value except RAND_MAX itself. */
          mat[i] = (T)rand() / (T)RAND_MAX;
      }
      return mat;
  }
  21.  
  22. void destroy_matrix(T* m)
  23. {
  24. free(m);
  25. }
  26.  
  27. void mm0(T* lhs, T* rhs, T* result, int n)
  28. {
  29. for (int i = 0; i < n; ++i)
  30. {
  31. for (int j = 0; j < n; ++j)
  32. {
  33. T sum = 0;
  34. for (int k = 0; k < n; ++k)
  35. {
  36. sum += lhs[i*n+k] * rhs[k*n+j];
  37. }
  38. result[i*n+j] = sum;
  39. }
  40. }
  41. }
  42.  
  43. void mm1(T* lhs, T* rhs, T* result, int n)
  44. {
  45. for (int i = n; --i;)
  46. {
  47. for (int j = n; --j;)
  48. {
  49. T sum = 0;
  50. for (int k = n; --k;)
  51. {
  52. sum += lhs[i*n+k] * rhs[k*n+j];
  53. }
  54. result[i*n+j] = sum;
  55. }
  56. }
  57. }
  58.  
  59. void mm2(T* lhs, T* rhs, T* result, int n)
  60. {
  61. for (int i = n; --i;)
  62. {
  63. for (int j = n; --j;)
  64. {
  65. T sum = 0;
  66. int k = 0;
  67. for (; k < n - 7; k += 8)
  68. {
  69. sum += lhs[i*n+k] * rhs[k*n+j];
  70. sum += lhs[i*n+k+1] * rhs[(k+1)*n+j];
  71. sum += lhs[i*n+k+2] * rhs[(k+2)*n+j];
  72. sum += lhs[i*n+k+3] * rhs[(k+3)*n+j];
  73. sum += lhs[i*n+k+4] * rhs[(k+4)*n+j];
  74. sum += lhs[i*n+k+5] * rhs[(k+5)*n+j];
  75. sum += lhs[i*n+k+6] * rhs[(k+6)*n+j];
  76. sum += lhs[i*n+k+7] * rhs[(k+7)*n+j];
  77. }
  78. for (; k < n; ++k)
  79. {
  80. sum += lhs[i*n+k] * rhs[k*n+j];
  81. }
  82. result[i*n+j] = sum;
  83. }
  84. }
  85. }
  86.  
  87. void mm3(T* lhs, T* rhs, T* result, int n)
  88. {
  89. for (int i = n; --i;)
  90. {
  91. for (int j = n; --j;)
  92. {
  93. T sum = 0;
  94. int k = 0;
  95. for (; k < n - 15; k += 16)
  96. {
  97. sum += lhs[i*n+k] * rhs[k*n+j];
  98. sum += lhs[i*n+k+1] * rhs[(k+1)*n+j];
  99. sum += lhs[i*n+k+2] * rhs[(k+2)*n+j];
  100. sum += lhs[i*n+k+3] * rhs[(k+3)*n+j];
  101. sum += lhs[i*n+k+4] * rhs[(k+4)*n+j];
  102. sum += lhs[i*n+k+5] * rhs[(k+5)*n+j];
  103. sum += lhs[i*n+k+6] * rhs[(k+6)*n+j];
  104. sum += lhs[i*n+k+7] * rhs[(k+7)*n+j];
  105. sum += lhs[i*n+k+8] * rhs[(k+8)*n+j];
  106. sum += lhs[i*n+k+9] * rhs[(k+9)*n+j];
  107. sum += lhs[i*n+k+10] * rhs[(k+10)*n+j];
  108. sum += lhs[i*n+k+11] * rhs[(k+11)*n+j];
  109. sum += lhs[i*n+k+12] * rhs[(k+12)*n+j];
  110. sum += lhs[i*n+k+13] * rhs[(k+13)*n+j];
  111. sum += lhs[i*n+k+14] * rhs[(k+14)*n+j];
  112. sum += lhs[i*n+k+15] * rhs[(k+15)*n+j];
  113. }
  114. for (; k < n; ++k)
  115. {
  116. sum += lhs[i*n+k] * rhs[k*n+j];
  117. }
  118. result[i*n+j] = sum;
  119. }
  120. }
  121. }
  122.  
  123. void mm4(T* lhs, T* rhs, T* result, int n)
  124. {
  125. for (int i = n; --i;)
  126. {
  127. for (int j = n; --j;)
  128. {
  129. T sum = 0;
  130. int k = 0;
  131. for (; k < n - 31; k += 32)
  132. {
  133. sum += lhs[i*n+k] * rhs[k*n+j];
  134. sum += lhs[i*n+k+1] * rhs[(k+1)*n+j];
  135. sum += lhs[i*n+k+2] * rhs[(k+2)*n+j];
  136. sum += lhs[i*n+k+3] * rhs[(k+3)*n+j];
  137. sum += lhs[i*n+k+4] * rhs[(k+4)*n+j];
  138. sum += lhs[i*n+k+5] * rhs[(k+5)*n+j];
  139. sum += lhs[i*n+k+6] * rhs[(k+6)*n+j];
  140. sum += lhs[i*n+k+7] * rhs[(k+7)*n+j];
  141. sum += lhs[i*n+k+8] * rhs[(k+8)*n+j];
  142. sum += lhs[i*n+k+9] * rhs[(k+9)*n+j];
  143. sum += lhs[i*n+k+10] * rhs[(k+10)*n+j];
  144. sum += lhs[i*n+k+11] * rhs[(k+11)*n+j];
  145. sum += lhs[i*n+k+12] * rhs[(k+12)*n+j];
  146. sum += lhs[i*n+k+13] * rhs[(k+13)*n+j];
  147. sum += lhs[i*n+k+14] * rhs[(k+14)*n+j];
  148. sum += lhs[i*n+k+15] * rhs[(k+15)*n+j];
  149. sum += lhs[i*n+k+16] * rhs[(k+16)*n+j];
  150. sum += lhs[i*n+k+17] * rhs[(k+17)*n+j];
  151. sum += lhs[i*n+k+18] * rhs[(k+18)*n+j];
  152. sum += lhs[i*n+k+19] * rhs[(k+19)*n+j];
  153. sum += lhs[i*n+k+20] * rhs[(k+20)*n+j];
  154. sum += lhs[i*n+k+21] * rhs[(k+21)*n+j];
  155. sum += lhs[i*n+k+22] * rhs[(k+22)*n+j];
  156. sum += lhs[i*n+k+23] * rhs[(k+23)*n+j];
  157. sum += lhs[i*n+k+24] * rhs[(k+24)*n+j];
  158. sum += lhs[i*n+k+25] * rhs[(k+25)*n+j];
  159. sum += lhs[i*n+k+26] * rhs[(k+26)*n+j];
  160. sum += lhs[i*n+k+27] * rhs[(k+27)*n+j];
  161. sum += lhs[i*n+k+28] * rhs[(k+28)*n+j];
  162. sum += lhs[i*n+k+29] * rhs[(k+29)*n+j];
  163. sum += lhs[i*n+k+30] * rhs[(k+30)*n+j];
  164. sum += lhs[i*n+k+31] * rhs[(k+31)*n+j];
  165. }
  166. for (; k < n; ++k)
  167. {
  168. sum += lhs[i*n+k] * rhs[k*n+j];
  169. }
  170. result[i*n+j] = sum;
  171. }
  172. }
  173. }
  174.  
  175. void transpose(T* from, T* to, int n)
  176. {
  177. for (int i = 0; i < n; ++i)
  178. {
  179. for (int j = 0; j < n; ++j)
  180. {
  181. to[i*n+j] = from[i+j*n];
  182. }
  183. }
  184. }
  185.  
  186. void mm5(T* lhs, T* rhs, T* result, int n)
  187. {
  188. T* rhs_t = make_matrix(n);
  189. transpose(rhs, rhs_t, n);
  190.  
  191. for (int i = n; --i;)
  192. {
  193. for (int j = n; --j;)
  194. {
  195. T sum = 0;
  196. int k = 0;
  197. for (; k < n - 31; k += 32)
  198. {
  199. sum += lhs[i*n+k] * rhs_t[k+j*n];
  200. sum += lhs[i*n+k+1] * rhs_t[(k+1)+j*n];
  201. sum += lhs[i*n+k+2] * rhs_t[(k+2)+j*n];
  202. sum += lhs[i*n+k+3] * rhs_t[(k+3)+j*n];
  203. sum += lhs[i*n+k+4] * rhs_t[(k+4)+j*n];
  204. sum += lhs[i*n+k+5] * rhs_t[(k+5)+j*n];
  205. sum += lhs[i*n+k+6] * rhs_t[(k+6)+j*n];
  206. sum += lhs[i*n+k+7] * rhs_t[(k+7)+j*n];
  207. sum += lhs[i*n+k+8] * rhs_t[(k+8)+j*n];
  208. sum += lhs[i*n+k+9] * rhs_t[(k+9)+j*n];
  209. sum += lhs[i*n+k+10] * rhs_t[(k+10)+j*n];
  210. sum += lhs[i*n+k+11] * rhs_t[(k+11)+j*n];
  211. sum += lhs[i*n+k+12] * rhs_t[(k+12)+j*n];
  212. sum += lhs[i*n+k+13] * rhs_t[(k+13)+j*n];
  213. sum += lhs[i*n+k+14] * rhs_t[(k+14)+j*n];
  214. sum += lhs[i*n+k+15] * rhs_t[(k+15)+j*n];
  215. sum += lhs[i*n+k+16] * rhs_t[(k+16)+j*n];
  216. sum += lhs[i*n+k+17] * rhs_t[(k+17)+j*n];
  217. sum += lhs[i*n+k+18] * rhs_t[(k+18)+j*n];
  218. sum += lhs[i*n+k+19] * rhs_t[(k+19)+j*n];
  219. sum += lhs[i*n+k+20] * rhs_t[(k+20)+j*n];
  220. sum += lhs[i*n+k+21] * rhs_t[(k+21)+j*n];
  221. sum += lhs[i*n+k+22] * rhs_t[(k+22)+j*n];
  222. sum += lhs[i*n+k+23] * rhs_t[(k+23)+j*n];
  223. sum += lhs[i*n+k+24] * rhs_t[(k+24)+j*n];
  224. sum += lhs[i*n+k+25] * rhs_t[(k+25)+j*n];
  225. sum += lhs[i*n+k+26] * rhs_t[(k+26)+j*n];
  226. sum += lhs[i*n+k+27] * rhs_t[(k+27)+j*n];
  227. sum += lhs[i*n+k+28] * rhs_t[(k+28)+j*n];
  228. sum += lhs[i*n+k+29] * rhs_t[(k+29)+j*n];
  229. sum += lhs[i*n+k+30] * rhs_t[(k+30)+j*n];
  230. sum += lhs[i*n+k+31] * rhs_t[(k+31)+j*n];
  231. }
  232. for (; k < n; ++k)
  233. {
  234. sum += lhs[i*n+k] * rhs_t[k+j*n];
  235. }
  236. result[i*n+j] = sum;
  237. }
  238. }
  239.  
  240. destroy_matrix(rhs_t);
  241. }
  242.  
  243. void mm6(T* lhs, T* rhs, T* result, int n)
  244. {
  245. T* rhs_t = make_matrix(n);
  246. transpose(rhs, rhs_t, n);
  247.  
  248. for (int i = n; --i;)
  249. {
  250. for (int j = n; --j;)
  251. {
  252. T sum0 = 0;
  253. T sum1 = 0;
  254. T sum2 = 0;
  255. T sum3 = 0;
  256. int k = 0;
  257. for (; k < n - 7; k += 8)
  258. {
  259. sum0 += lhs[i*n+k] * rhs_t[k+j*n];
  260. sum1 += lhs[i*n+k+1] * rhs_t[(k+1)+j*n];
  261. sum2 += lhs[i*n+k+2] * rhs_t[(k+2)+j*n];
  262. sum3 += lhs[i*n+k+3] * rhs_t[(k+3)+j*n];
  263. sum0 += lhs[i*n+k+4] * rhs_t[(k+4)+j*n];
  264. sum1 += lhs[i*n+k+5] * rhs_t[(k+5)+j*n];
  265. sum2 += lhs[i*n+k+6] * rhs_t[(k+6)+j*n];
  266. sum3 += lhs[i*n+k+7] * rhs_t[(k+7)+j*n];
  267. }
  268. for (; k < n; ++k)
  269. {
  270. sum0 += lhs[i*n+k] * rhs_t[k+j*n];
  271. }
  272. result[i*n+j] = sum0 + sum1 + sum2 + sum3;
  273. }
  274. }
  275.  
  276. destroy_matrix(rhs_t);
  277. }
  278.  
  279. void mm7(T* lhs, T* rhs, T* result, int n)
  280. {
  281. T* rhs_t = make_matrix(n);
  282. transpose(rhs, rhs_t, n);
  283.  
  284. for (int i = 0; i < n; ++i)
  285. {
  286. int j = 0;
  287. for (; j < n - 7; j += 4)
  288. {
  289. T sum0 = 0;
  290. T sum1 = 0;
  291. T sum2 = 0;
  292. T sum3 = 0;
  293. int k = 0;
  294. for (; k < n - 7; k += 8)
  295. {
  296. const T a0 = lhs[i*n+k+0];
  297. const T a1 = lhs[i*n+k+1];
  298. const T a2 = lhs[i*n+k+2];
  299. const T a3 = lhs[i*n+k+3];
  300. const T a4 = lhs[i*n+k+4];
  301. const T a5 = lhs[i*n+k+5];
  302. const T a6 = lhs[i*n+k+6];
  303. const T a7 = lhs[i*n+k+7];
  304.  
  305. sum0 += a0 * rhs_t[(k+0)+j*n];
  306. sum1 += a0 * rhs_t[(k+0)+(j+1)*n];
  307. sum2 += a0 * rhs_t[(k+0)+(j+2)*n];
  308. sum3 += a0 * rhs_t[(k+0)+(j+3)*n];
  309.  
  310. sum0 += a1 * rhs_t[(k+1)+j*n];
  311. sum1 += a1 * rhs_t[(k+1)+(j+1)*n];
  312. sum2 += a1 * rhs_t[(k+1)+(j+2)*n];
  313. sum3 += a1 * rhs_t[(k+1)+(j+3)*n];
  314.  
  315. sum0 += a2 * rhs_t[(k+2)+j*n];
  316. sum1 += a2 * rhs_t[(k+2)+(j+1)*n];
  317. sum2 += a2 * rhs_t[(k+2)+(j+2)*n];
  318. sum3 += a2 * rhs_t[(k+2)+(j+3)*n];
  319.  
  320. sum0 += a3 * rhs_t[(k+3)+j*n];
  321. sum1 += a3 * rhs_t[(k+3)+(j+1)*n];
  322. sum2 += a3 * rhs_t[(k+3)+(j+2)*n];
  323. sum3 += a3 * rhs_t[(k+3)+(j+3)*n];
  324.  
  325. sum0 += a4 * rhs_t[(k+4)+j*n];
  326. sum1 += a4 * rhs_t[(k+4)+(j+1)*n];
  327. sum2 += a4 * rhs_t[(k+4)+(j+2)*n];
  328. sum3 += a4 * rhs_t[(k+4)+(j+3)*n];
  329.  
  330. sum0 += a5 * rhs_t[(k+5)+j*n];
  331. sum1 += a5 * rhs_t[(k+5)+(j+1)*n];
  332. sum2 += a5 * rhs_t[(k+5)+(j+2)*n];
  333. sum3 += a5 * rhs_t[(k+5)+(j+3)*n];
  334.  
  335. sum0 += a6 * rhs_t[(k+6)+j*n];
  336. sum1 += a6 * rhs_t[(k+6)+(j+1)*n];
  337. sum2 += a6 * rhs_t[(k+6)+(j+2)*n];
  338. sum3 += a6 * rhs_t[(k+6)+(j+3)*n];
  339.  
  340. sum0 += a7 * rhs_t[(k+7)+j*n];
  341. sum1 += a7 * rhs_t[(k+7)+(j+1)*n];
  342. sum2 += a7 * rhs_t[(k+7)+(j+2)*n];
  343. sum3 += a7 * rhs_t[(k+7)+(j+3)*n];
  344. }
  345. result[i*n+j] = sum0;
  346. result[i*n+j+1] = sum1;
  347. result[i*n+j+2] = sum2;
  348. result[i*n+j+3] = sum3;
  349. }
  350. }
  351.  
  352. destroy_matrix(rhs_t);
  353. }
  354.  
  /* Returns acc + a01*b[0..1] + a23*b[2..3] + a45*b[4..5] + a67*b[6..7],
   * added in that order.  Loads are unaligned, so b needs no particular
   * alignment. */
  static inline __m128d mm8_d_accumulate8(__m128d acc, __m128d a01, __m128d a23,
                                          __m128d a45, __m128d a67,
                                          const double* b)
  {
      acc = _mm_add_pd(acc, _mm_mul_pd(a01, _mm_loadu_pd(b + 0)));
      acc = _mm_add_pd(acc, _mm_mul_pd(a23, _mm_loadu_pd(b + 2)));
      acc = _mm_add_pd(acc, _mm_mul_pd(a45, _mm_loadu_pd(b + 4)));
      acc = _mm_add_pd(acc, _mm_mul_pd(a67, _mm_loadu_pd(b + 6)));
      return acc;
  }

  /* Returns {x[0] + x[1], y[0] + y[1]}: the same value as SSE3
   * _mm_hadd_pd(x, y), but built from SSE2 instructions only. */
  static inline __m128d mm8_d_hsum2(__m128d x, __m128d y)
  {
      return _mm_add_pd(_mm_unpacklo_pd(x, y), _mm_unpackhi_pd(x, y));
  }

  /*
   * result = lhs * rhs for n x n row-major double matrices, using SSE2.
   *
   * rhs is transposed into a private double buffer, so each output element
   * is a dot product of two contiguous rows.  The main loop computes 8
   * output columns at a time, with one 2-lane accumulator per column, and
   * consumes k in steps of 8.  Scalar loops finish the leftover k terms
   * and columns.  All loads and stores are unaligned, so any n works
   * (n <= 0 does nothing).  Aborts if the scratch allocation fails.
   */
  void mm8_d(double* lhs, double* rhs, double* result, int n)
  {
      if (n <= 0)
      {
          return;
      }
      const size_t stride = (size_t)n;

      /* make_matrix() and transpose() work on T (currently float), not
       * double, so the double-precision transpose is built here. */
      double* rhs_t = malloc(sizeof *rhs_t * stride * stride);
      if (rhs_t == NULL)
      {
          fprintf(stderr, "mm8_d: out of memory (n = %d)\n", n);
          abort();
      }
      for (size_t r = 0; r < stride; ++r)
      {
          for (size_t c = 0; c < stride; ++c)
          {
              rhs_t[r * stride + c] = rhs[c * stride + r];
          }
      }

      for (size_t i = 0; i < stride; ++i)
      {
          const double* a = lhs + i * stride;     /* row i of lhs */
          double* out = result + i * stride;      /* row i of result */
          int j = 0;
          for (; j < n - 7; j += 8)
          {
              /* b + c*stride is row j+c of rhs_t, i.e. column j+c of rhs. */
              const double* b = rhs_t + (size_t)j * stride;
              __m128d sum0 = _mm_setzero_pd();
              __m128d sum1 = _mm_setzero_pd();
              __m128d sum2 = _mm_setzero_pd();
              __m128d sum3 = _mm_setzero_pd();
              __m128d sum4 = _mm_setzero_pd();
              __m128d sum5 = _mm_setzero_pd();
              __m128d sum6 = _mm_setzero_pd();
              __m128d sum7 = _mm_setzero_pd();
              int k = 0;
              for (; k < n - 7; k += 8)
              {
                  const __m128d a01 = _mm_loadu_pd(a + k + 0);
                  const __m128d a23 = _mm_loadu_pd(a + k + 2);
                  const __m128d a45 = _mm_loadu_pd(a + k + 4);
                  const __m128d a67 = _mm_loadu_pd(a + k + 6);

                  sum0 = mm8_d_accumulate8(sum0, a01, a23, a45, a67, b + 0 * stride + k);
                  sum1 = mm8_d_accumulate8(sum1, a01, a23, a45, a67, b + 1 * stride + k);
                  sum2 = mm8_d_accumulate8(sum2, a01, a23, a45, a67, b + 2 * stride + k);
                  sum3 = mm8_d_accumulate8(sum3, a01, a23, a45, a67, b + 3 * stride + k);
                  sum4 = mm8_d_accumulate8(sum4, a01, a23, a45, a67, b + 4 * stride + k);
                  sum5 = mm8_d_accumulate8(sum5, a01, a23, a45, a67, b + 5 * stride + k);
                  sum6 = mm8_d_accumulate8(sum6, a01, a23, a45, a67, b + 6 * stride + k);
                  sum7 = mm8_d_accumulate8(sum7, a01, a23, a45, a67, b + 7 * stride + k);
              }

              /* Reduce each accumulator to a scalar: tmp[c] is column j+c. */
              double tmp[8];
              _mm_storeu_pd(tmp + 0, mm8_d_hsum2(sum0, sum1));
              _mm_storeu_pd(tmp + 2, mm8_d_hsum2(sum2, sum3));
              _mm_storeu_pd(tmp + 4, mm8_d_hsum2(sum4, sum5));
              _mm_storeu_pd(tmp + 6, mm8_d_hsum2(sum6, sum7));

              /* Trailing k terms when n is not a multiple of 8. */
              for (; k < n; ++k)
              {
                  for (size_t c = 0; c < 8; ++c)
                  {
                      tmp[c] += a[k] * b[c * stride + (size_t)k];
                  }
              }
              for (size_t c = 0; c < 8; ++c)
              {
                  out[(size_t)j + c] = tmp[c];
              }
          }

          /* Columns not covered by the 8-wide loop. */
          for (; j < n; ++j)
          {
              const double* b = rhs_t + (size_t)j * stride;
              double sum = 0.0;
              for (size_t k = 0; k < stride; ++k)
              {
                  sum += a[k] * b[k];
              }
              out[j] = sum;
          }
      }

      free(rhs_t);
  }
  430.  
  431.  
  /* Returns acc + a0*b[0..3] + a1*b[4..7] + a2*b[8..11] + a3*b[12..15],
   * added in that order.  Loads are unaligned, so b needs no particular
   * alignment. */
  static inline __m128 mm8_f_accumulate16(__m128 acc, __m128 a0, __m128 a1,
                                          __m128 a2, __m128 a3, const float* b)
  {
      acc = _mm_add_ps(acc, _mm_mul_ps(a0, _mm_loadu_ps(b + 0)));
      acc = _mm_add_ps(acc, _mm_mul_ps(a1, _mm_loadu_ps(b + 4)));
      acc = _mm_add_ps(acc, _mm_mul_ps(a2, _mm_loadu_ps(b + 8)));
      acc = _mm_add_ps(acc, _mm_mul_ps(a3, _mm_loadu_ps(b + 12)));
      return acc;
  }

  /* Returns a vector whose lane c is the horizontal sum of input c, grouped
   * as (x[0] + x[1]) + (x[2] + x[3]).  That is the same grouping as two
   * rounds of SSE3 _mm_hadd_ps, but this version uses SSE1 shuffles only. */
  static inline __m128 mm8_f_hsum4(__m128 s0, __m128 s1, __m128 s2, __m128 s3)
  {
      /* After the transpose, register r holds element r of every input. */
      _MM_TRANSPOSE4_PS(s0, s1, s2, s3);
      return _mm_add_ps(_mm_add_ps(s0, s1), _mm_add_ps(s2, s3));
  }

  /*
   * result = lhs * rhs for n x n row-major float matrices, using SSE.
   *
   * rhs is transposed into a private float buffer, so each output element
   * is a dot product of two contiguous rows.  The main loop computes 8
   * output columns at a time, with one 4-lane accumulator per column, and
   * consumes k in steps of 16.  Scalar loops finish the leftover k terms
   * and columns.  All loads and stores are unaligned, so any n works
   * (n <= 0 does nothing).  Aborts if the scratch allocation fails.
   */
  void mm8_f(float* lhs, float* rhs, float* result, int n)
  {
      if (n <= 0)
      {
          return;
      }
      const size_t stride = (size_t)n;

      /* Transpose locally in float, so this kernel does not depend on T. */
      float* rhs_t = malloc(sizeof *rhs_t * stride * stride);
      if (rhs_t == NULL)
      {
          fprintf(stderr, "mm8_f: out of memory (n = %d)\n", n);
          abort();
      }
      for (size_t r = 0; r < stride; ++r)
      {
          for (size_t c = 0; c < stride; ++c)
          {
              rhs_t[r * stride + c] = rhs[c * stride + r];
          }
      }

      for (size_t i = 0; i < stride; ++i)
      {
          const float* a = lhs + i * stride;      /* row i of lhs */
          float* out = result + i * stride;       /* row i of result */
          int j = 0;
          for (; j < n - 7; j += 8)
          {
              /* b + c*stride is row j+c of rhs_t, i.e. column j+c of rhs. */
              const float* b = rhs_t + (size_t)j * stride;
              __m128 sum0 = _mm_setzero_ps();
              __m128 sum1 = _mm_setzero_ps();
              __m128 sum2 = _mm_setzero_ps();
              __m128 sum3 = _mm_setzero_ps();
              __m128 sum4 = _mm_setzero_ps();
              __m128 sum5 = _mm_setzero_ps();
              __m128 sum6 = _mm_setzero_ps();
              __m128 sum7 = _mm_setzero_ps();
              int k = 0;
              for (; k < n - 15; k += 16)
              {
                  const __m128 a0 = _mm_loadu_ps(a + k + 0);
                  const __m128 a1 = _mm_loadu_ps(a + k + 4);
                  const __m128 a2 = _mm_loadu_ps(a + k + 8);
                  const __m128 a3 = _mm_loadu_ps(a + k + 12);

                  sum0 = mm8_f_accumulate16(sum0, a0, a1, a2, a3, b + 0 * stride + k);
                  sum1 = mm8_f_accumulate16(sum1, a0, a1, a2, a3, b + 1 * stride + k);
                  sum2 = mm8_f_accumulate16(sum2, a0, a1, a2, a3, b + 2 * stride + k);
                  sum3 = mm8_f_accumulate16(sum3, a0, a1, a2, a3, b + 3 * stride + k);
                  sum4 = mm8_f_accumulate16(sum4, a0, a1, a2, a3, b + 4 * stride + k);
                  sum5 = mm8_f_accumulate16(sum5, a0, a1, a2, a3, b + 5 * stride + k);
                  sum6 = mm8_f_accumulate16(sum6, a0, a1, a2, a3, b + 6 * stride + k);
                  sum7 = mm8_f_accumulate16(sum7, a0, a1, a2, a3, b + 7 * stride + k);
              }

              /* Reduce each accumulator to a scalar: tmp[c] is column j+c. */
              float tmp[8];
              _mm_storeu_ps(tmp + 0, mm8_f_hsum4(sum0, sum1, sum2, sum3));
              _mm_storeu_ps(tmp + 4, mm8_f_hsum4(sum4, sum5, sum6, sum7));

              /* Trailing k terms when n is not a multiple of 16. */
              for (; k < n; ++k)
              {
                  for (size_t c = 0; c < 8; ++c)
                  {
                      tmp[c] += a[k] * b[c * stride + (size_t)k];
                  }
              }
              for (size_t c = 0; c < 8; ++c)
              {
                  out[(size_t)j + c] = tmp[c];
              }
          }

          /* Columns not covered by the 8-wide loop. */
          for (; j < n; ++j)
          {
              const float* b = rhs_t + (size_t)j * stride;
              float sum = 0.0f;
              for (size_t k = 0; k < stride; ++k)
              {
                  sum += a[k] * b[k];
              }
              out[j] = sum;
          }
      }

      free(rhs_t);
  }
  509.  
  510. float time_it(void (*f)(T*, T*, T*, int), int n, float min_t)
  511. {
  512. T* lhs = make_matrix(n);
  513. T* rhs = make_matrix(n);
  514. T* result = make_matrix(n);
  515.  
  516. clock_t start = clock();
  517. for(int i = 1;; ++i)
  518. {
  519. f(lhs, rhs, result, n);
  520.  
  521. clock_t end = clock();
  522. float seconds = (float)(end - start) / CLOCKS_PER_SEC;
  523. if (seconds > min_t)
  524. {
  525. return seconds / i;
  526. }
  527. }
  528.  
  529. destroy_matrix(lhs);
  530. destroy_matrix(rhs);
  531. destroy_matrix(result);
  532. }
  533.  
  534. int main()
  535. {
  536. const int n = 512;
  537. const float min_t = 4.0f;
  538. /*
  539. printf("mm0: %f\n", time_it(mm0, n, min_t));
  540. printf("mm1: %f\n", time_it(mm1, n, min_t));
  541. printf("mm2: %f\n", time_it(mm2, n, min_t));
  542. printf("mm3: %f\n", time_it(mm3, n, min_t));
  543. printf("mm4: %f\n", time_it(mm4, n, min_t));
  544. printf("mm5: %f\n", time_it(mm5, n, min_t));
  545. printf("mm6: %f\n", time_it(mm6, n, min_t));
  546. printf("mm7: %f\n", time_it(mm7, n, min_t));
  547. printf("mm8: %f\n", time_it(mm8_f, n, min_t));
  548. */
  549. printf("mm8: %f\n", time_it(mm8, n, min_t));
  550. return 0;
  551. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement