Advertisement
Guest User

Untitled

a guest
Nov 16th, 2018
92
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.86 KB | None | 0 0
  1. //#pragma GCC optimize ("O3")
  2. //#pragma GCC optimize ("unroll-loops")
  3. #pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,sse4a,avx,avx2")
  4.  
  5. #include <stdio.h>
  6. #include <cstdlib>
  7. #include <iostream>
  8. #include <vector>
  9. #include <random>
  10. #include <chrono>
  11. #include <immintrin.h>
  12.  
  // Maximum side length of the statically allocated square matrices.
  constexpr int NMAX = 1024;
  // NMAX x NMAX int matrices with row stride NMAX, 32-byte aligned so every row
  // start is suitably aligned for the aligned SIMD loads used by dot().
  //   A - left operand (row-major)
  //   B - right operand, stored transposed (B[j * NMAX + i] = input row i, col j)
  //   R - product computed by mult()
  //   C - product to verify against (read from input or built by test())
  alignas(32) int A[NMAX * NMAX];
  alignas(32) int B[NMAX * NMAX];
  alignas(32) int R[NMAX * NMAX];
  alignas(32) int C[NMAX * NMAX];
  15.  
  // Returns the next byte of stdin, refilling a 1 MiB static buffer with
  // fread() whenever the buffered bytes have all been handed out.
  // When a refill yields no data (end of input), returns -1 converted to char.
  char getChar() {
      static char buf[1 << 20];
      static int filled = 0;  // number of valid bytes currently in buf
      static int cursor = 0;  // index of the next byte to return
      if (cursor == filled) {
          cursor = 0;
          filled = static_cast<int>(fread(buf, 1, sizeof buf, stdin));
      }
      if (cursor == filled) {
          return -1;
      }
      return buf[cursor++];
  }
  25.  
  26. int read() {
  27. char c = '?';
  28. while (!(c == '-' || ('0' <= c && c <= '9'))) { c = getChar(); }
  29. bool neg = false;
  30. if (c == '-') { neg = true; c = getChar(); }
  31. int ret = 0;
  32. while ('0' <= c && c <= '9') { (ret *= 10) += (c - '0'); c = getChar(); }
  33. return neg ? -ret : ret;
  34. }
  35.  
  // Prime modulus (1e9 + 7) used by add(), mul(), dot() and rand_mat().
  constexpr int mod = 1000000007;
  37.  
  38. int add(int a, int b) {
  39. return (a += b) >= mod ? a - mod : a;
  40. }
  41.  
  42. int mul(int a, int b) {
  43. return int(1LL * a * b % mod);
  44. }
  45.  
  46. void rand_mat(int *m, int n) {
  47. uint64_t seed = std::chrono::high_resolution_clock::now().time_since_epoch().count();
  48. seed ^= (uint64_t)(new uint64_t);
  49. std::mt19937 gen(seed);
  50. std::uniform_int_distribution<int> dist(0, mod - 1);
  51. for (int i = 0; i < n; ++i) {
  52. for (int j = 0; j < n; ++j) {
  53. m[i*NMAX + j] = dist(gen);
  54. }
  55. }
  56. }
  57.  
  58. void input(int& n) {
  59. n = read();
  60. for (int i = 0; i < n; ++i) {
  61. for (int j = 0; j < n; ++j) {
  62. A[i*NMAX + j] = read();
  63. }
  64. }
  65. for (int i = 0; i < n; ++i) {
  66. for (int j = 0; j < n; ++j) {
  67. B[j*NMAX + i] = read();
  68. }
  69. }
  70. for (int i = 0; i < n; ++i) {
  71. for (int j = 0; j < n; ++j) {
  72. C[i*NMAX + j] = read();
  73. }
  74. }
  75. }
  76.  
  // Loads four consecutive int32 values from x (which must be 16-byte aligned,
  // as required by _mm_load_si128) and sign-extends each into one 64-bit lane.
  __m256i _mm256_load_pi32x4(const int *x) {
      const __m128i packed = _mm_load_si128(reinterpret_cast<const __m128i*>(x));
      return _mm256_cvtepi32_epi64(packed);
  }
  80.  
  // Writes the four 64-bit lanes of r to x[0..3]; x must be 32-byte aligned
  // (requirement of _mm256_store_si256).
  void _mm256_store_pi64(long long * x, __m256i r) {
      __m256i *dst = reinterpret_cast<__m256i*>(x);
      _mm256_store_si256(dst, r);
  }
  84.  
  85. int dot(int n, const int *a, const int *b) {
  86. const int gsize = 16;
  87. uint64_t high = 0, low = 0;
  88. for (int g = 0; g + gsize <= n; g += gsize) {
  89. alignas(64) long long temp[4];
  90.  
  91. __m256i r1, r2, r3, r4;
  92. r1 = _mm256_mul_epi32(
  93. _mm256_load_pi32x4((int*)(a + g + 0)),
  94. _mm256_load_pi32x4((int*)(b + g + 0)));
  95.  
  96. r2 = _mm256_mul_epi32(
  97. _mm256_load_pi32x4((int*)(a + g + 4)),
  98. _mm256_load_pi32x4((int*)(b + g + 4)));
  99.  
  100. r3 = _mm256_mul_epi32(
  101. _mm256_load_pi32x4((int*)(a + g + 8)),
  102. _mm256_load_pi32x4((int*)(b + g + 8)));
  103.  
  104. r4 = _mm256_mul_epi32(
  105. _mm256_load_pi32x4((int*)(a + g + 12)),
  106. _mm256_load_pi32x4((int*)(b + g + 12)));
  107.  
  108. _mm256_store_pi64(temp, _mm256_add_epi64(
  109. _mm256_add_epi64(r1, r2), _mm256_add_epi64(r3, r4)
  110. ));
  111.  
  112. const auto old = low;
  113. low += temp[0] + temp[1];
  114. low += temp[2] + temp[3];
  115. high += (old > low);
  116. }
  117. uint64_t res = (low % mod + (((high << 32) % mod) << 32)) % mod;
  118. for (int i = n / gsize * gsize; i < n; ++i) {
  119. res += 1LL * a[i] * b[i];
  120. }
  121. return res % mod;
  122. }
  123.  
  124. void mult(int n) {
  125. for (int i = 0; i < n; ++i) {
  126. for (int j = 0; j < n; ++j) {
  127. R[i*NMAX + j] = dot(n, &A[i*NMAX], &B[j*NMAX]);
  128. }
  129. }
  130. }
  131.  
  132. bool check(const int n) {
  133. for (int i = 0; i < n; ++i) {
  134. for (int j = 0; j < n; ++j) {
  135. if (C[i * NMAX + j] != R[i * NMAX + j]) {
  136. return false;
  137. }
  138. }
  139. }
  140. return true;
  141. }
  142.  
  143. void gen(int n) {
  144. rand_mat(A, n), rand_mat(B, n);
  145. }
  146.  
  147. void test(int n) {
  148. gen(n);
  149. mult(n);
  150. for (int i = 0; i < n; ++i) {
  151. for (int j = 0; j < n; ++j) {
  152. int res = 0;
  153. for (int k = 0; k < n; ++k) {
  154. res = add(res, mul(A[i*NMAX + k], B[j*NMAX + k]));
  155. }
  156. C[i*NMAX + j] = res;
  157. }
  158. }
  159. std::cout << (check(n) ? "YES" : "NO") << std::endl;
  160. std::exit(0);
  161. }
  162.  
  163. int main() {
  164. //test(100);
  165. int n;
  166. input(n);
  167. mult(n);
  168. std::cout << (check(n) ? "YES" : "NO") << std::endl;
  169. return 0;
  170. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement