#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// Atomic add for half values, adapted from
// https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
// Emulates a 16-bit atomic add with a 32-bit atomicCAS on the aligned word
// that contains the target half.
__device__ __forceinline__ void atomicAdd2(__half* address, c10::Half val) {
  unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    // Pull out the 16 bits that belong to `address` and add `val` in half
    // precision (reinterpret the bits as a half rather than adding the value
    // to the raw integer pattern).
    c10::Half hsum;
    hsum.x = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    hsum = hsum + val;
    old = reinterpret_cast<size_t>(address) & 2
      ? (old & 0xffff) | (hsum.x << 16)
      : (old & 0xffff0000) | hsum.x;
    old = atomicCAS(address_as_ui, assumed, old);

  // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);
}

template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
);

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
);

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
);

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
);

const int BLOCKWIDTH = 256;
const int BLOCKHEIGHT2 = 16;
const int BLOCKHEIGHT3 = 24;
const int BLOCKHEIGHT4 = 32;
const int BLOCKHEIGHT8 = 64;
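// BLOCKHEIGHT{2,3,4,8} is the number of packed 32-bit rows needed to hold
// BLOCKWIDTH quantized values at that bit width, i.e. BLOCKWIDTH * bits / 32
// (256 * 2 / 32 = 16, 256 * 3 / 32 = 24, 256 * 4 / 32 = 32, 256 * 8 / 32 = 64).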

__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}

void vecquant2matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant2matmul_cuda", ([&] {
      VecQuant2MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT2 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage the slice of the input vector used by this block in shared memory.
  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT2) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  // Each 32-bit word of `mat` packs sixteen 2-bit weights for this column.
  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 6) & 0x3) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 8) & 0x3) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 10) & 0x3) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 12) & 0x3) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 14) & 0x3) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp >> 16) & 0x3) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp >> 18) & 0x3) - zero) * blockvec[k + 9];
    res += (scale * scalar_t((tmp >> 20) & 0x3) - zero) * blockvec[k + 10];
    res += (scale * scalar_t((tmp >> 22) & 0x3) - zero) * blockvec[k + 11];
    res += (scale * scalar_t((tmp >> 24) & 0x3) - zero) * blockvec[k + 12];
    res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
    res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
    res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];
    i += width;
    k += 16;
  }

  // Accumulate this block's partial dot product into the output column.
  atomicAdd(&mul[b * width + w], res);
}

void vecquant3matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant3matmul_cuda", ([&] {
      VecQuant3MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}

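// 3-bit packing: every group of three consecutive 32-bit words along a column
// holds 32 weights. The first ten values come whole from word 1, the next ten
// from word 2 and the last ten from word 3; the 11th and 22nd values straddle
// word boundaries and are stitched together from the leftover high bits of one
// word and the low bits of the next.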
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT3 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT3) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp1;
  unsigned int tmp2;
  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp1 = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    tmp2 = as_unsigned(mat[i]);
    // Straddling value: top 2 bits of tmp1 plus the low bit of tmp2.
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
    tmp2 >>= 1;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;
    res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    tmp1 = as_unsigned(mat[i]);
    // Straddling value: top bit of tmp2 plus the low 2 bits of the new tmp1.
    tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
    tmp1 >>= 2;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;
    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    k += 10;
  }

  atomicAdd(&mul[b * width + w], res);
}

void vecquant4matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant4matmul_cuda", ([&] {
      VecQuant4MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 12) & 0xF) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 16) & 0xF) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];
    i += width;
    k += 8;
  }

  atomicAdd(&mul[b * width + w], res);
}

void vecquant8matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT8 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT8) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];
    i += width;
    k += 4;
  }

  atomicAdd(&mul[b * width + w], res);
}

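// The transpose kernels below multiply the input vector by rows of the
// unpacked weight matrix (i.e. by its transpose): each thread owns one of the
// 8 * BLOCKHEIGHT4 unpacked rows in its block (threadIdx.x / 8 selects the
// packed 32-bit row, threadIdx.x % 8 the nibble), walks along the width of the
// packed matrix, and accumulates into mul[b * height * 8 + n_rows].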
template <typename scalar_t>
__global__ void VecQuant4TransposeMatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
  unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
  int w = BLOCKWIDTH * blockIdx.y;

  int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
  int n_cols = b;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[n_cols * vec_height + w + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;
  int j = w;
  unsigned int tmp;
  // Walk along the packed row (the width direction), dotting one unpacked
  // weight row against a BLOCKWIDTH slice of the input vector.
  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scales[j] * scalar_t((tmp >> shift) & 0xF) - zeros[j]) * blockvec[k];
    i += 1;
    j += 1;
    k += 1;
  }

  atomicAdd(&mul[n_cols * height * 8 + n_rows], res);
}

void vecquant4transposematmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant4transposematmul_cuda", ([&] {
      VecQuant4TransposeMatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant4MatMulHalfKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ __half blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = __half(vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x]);
  __syncthreads();

  __half scale = __half(scales[w]);
  __half zero = __half(zeros[w]);

  __half res = __float2half(0.0f);
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  // Same unpacking as VecQuant4MatMulKernel, but using half-precision intrinsics.
  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 0) & 0xF)), zero), blockvec[k + 0]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 4) & 0xF)), zero), blockvec[k + 1]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 8) & 0xF)), zero), blockvec[k + 2]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 12) & 0xF)), zero), blockvec[k + 3]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 16) & 0xF)), zero), blockvec[k + 4]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 20) & 0xF)), zero), blockvec[k + 5]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 24) & 0xF)), zero), blockvec[k + 6]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 28) & 0xF)), zero), blockvec[k + 7]));
    i += width;
    k += 8;
  }

  __half* mul2 = (__half*)mul;
  atomicAdd2(&mul2[b * width + w], res);
}

void vecquant4matmul_half_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_SWITCH(vec.type(), "vecquant4matmul_half_cuda",
    AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
      VecQuant4MatMulHalfKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  ));
}

template <typename scalar_t>
__global__ void VecQuant4TransposeMatMulHalfKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
  unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
  int w = BLOCKWIDTH * blockIdx.y;

  int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
  int n_cols = b;

  __shared__ __half blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = __half(vec[n_cols * vec_height + w + threadIdx.x]);
  __syncthreads();

  __half res = __float2half(0.0f);
  int i = width * h + w;
  int k = 0;
  int j = w;
  unsigned int tmp;
  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    __half zero = __half(zeros[j]);
    __half scale = __half(scales[j]);
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> shift) & 0xF)), zero), blockvec[k]));
    i += 1;
    j += 1;
    k += 1;
  }

  __half* mul2 = (__half*)mul;
  atomicAdd2(&mul2[n_cols * height * 8 + n_rows], res);
}

void vecquant4transposematmul_half_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_SWITCH(vec.type(), "vecquant4transposematmul_half_cuda",
    AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
      VecQuant4TransposeMatMulHalfKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  ));
}

template <typename scalar_t>
__global__ void VecQuant4ReconsKernel(
    const int* __restrict__ mat,
    scalar_t* __restrict__ res,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int height,
    int width
) {
  // Dequantizes the packed 4-bit matrix: each thread writes one element of the
  // reconstructed (height * 8) x width matrix.
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
  int n_rows = h * 8 + b;
  int n_cols = w;
  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];
  int i = width * h + width * (b / 8) + w;
  int shift = b % 8 * 4;
  unsigned int tmp = as_unsigned(mat[i]);
  scalar_t result = (scale * scalar_t((tmp >> shift) & 0xF) - zero);
  res[n_rows * width + n_cols] = result;
}

void vecquant4recons_cuda(
  torch::Tensor mat,
  torch::Tensor res,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = BLOCKWIDTH;
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    scales.type(), "vecquant4recons_cuda", ([&] {
      VecQuant4ReconsKernel<<<blocks, threads>>>(
        mat.data<int>(), res.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        height, width
      );
    })
  );
}
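
// The paste ends with the host wrappers; how they reach Python is not shown.
// Below is a minimal binding sketch, NOT part of the original file: it assumes
// the source is built with torch.utils.cpp_extension (which defines
// TORCH_EXTENSION_NAME) and simply exports the *_cuda wrappers defined above.
// If the project already defines its bindings elsewhere, omit this block.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("vecquant2matmul_cuda", &vecquant2matmul_cuda, "2-bit quantized matvec (CUDA)");
  m.def("vecquant3matmul_cuda", &vecquant3matmul_cuda, "3-bit quantized matvec (CUDA)");
  m.def("vecquant4matmul_cuda", &vecquant4matmul_cuda, "4-bit quantized matvec (CUDA)");
  m.def("vecquant8matmul_cuda", &vecquant8matmul_cuda, "8-bit quantized matvec (CUDA)");
  m.def("vecquant4transposematmul_cuda", &vecquant4transposematmul_cuda, "4-bit quantized transposed matvec (CUDA)");
  m.def("vecquant4matmul_half_cuda", &vecquant4matmul_half_cuda, "4-bit quantized matvec, fp16 (CUDA)");
  m.def("vecquant4transposematmul_half_cuda", &vecquant4transposematmul_half_cuda, "4-bit quantized transposed matvec, fp16 (CUDA)");
  m.def("vecquant4recons_cuda", &vecquant4recons_cuda, "Dequantize a packed 4-bit matrix (CUDA)");
}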