Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#include <torch/all.h>
#include <torch/python.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// Atomically adds `val` to the fp16 value stored at `address`.
// Adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
//
// On SM70+ this uses the native atomicAdd(__half*, __half). On older
// architectures it emulates the add with a 32-bit compare-and-swap loop on the
// 4-byte-aligned word that contains the half; bit 1 of the address selects
// which 16-bit lane of that word is updated. `address` must be 2-byte aligned.
__device__ __forceinline__ void atomicAdd2(__half* address, c10::Half val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  atomicAdd(address, __float2half(static_cast<float>(val)));
#else
  const size_t in_high_lane = reinterpret_cast<size_t>(address) & 2;
  unsigned int* address_as_ui = reinterpret_cast<unsigned int*>(
      reinterpret_cast<char*>(address) - in_high_lane);
  const float addend = static_cast<float>(val);
  unsigned int old = *address_as_ui;
  unsigned int assumed;
  do {
    assumed = old;
    // Extract the stored 16 bits and reinterpret them as a half before adding;
    // adding `val` to the raw bit pattern (as an integer) is meaningless.
    const unsigned short old_bits = in_high_lane
        ? static_cast<unsigned short>(old >> 16)
        : static_cast<unsigned short>(old & 0xffffu);
    const __half sum = __float2half(__half2float(__ushort_as_half(old_bits)) + addend);
    const unsigned int sum_bits = __half_as_ushort(sum);
    const unsigned int updated = in_high_lane
        ? (old & 0x0000ffffu) | (sum_bits << 16)
        : (old & 0xffff0000u) | sum_bits;
    old = atomicCAS(address_as_ui, assumed, updated);
    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);
#endif
}
// Forward declaration of the 2-bit dequantize-and-multiply kernel. It is
// needed because vecquant2matmul_cuda launches the kernel before its
// definition appears in this file.
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch, int vec_height, int height, int width);
// Forward declaration of the 3-bit dequantize-and-multiply kernel. It is
// needed because vecquant3matmul_cuda launches the kernel before its
// definition appears in this file.
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch, int vec_height, int height, int width);
// Forward declaration of the 4-bit dequantize-and-multiply kernel. It is
// needed because vecquant4matmul_cuda launches the kernel before its
// definition appears in this file.
template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch, int vec_height, int height, int width);
// Forward declaration of the 8-bit dequantize-and-multiply kernel. It is
// needed because vecquant8matmul_cuda launches the kernel before its
// definition appears in this file.
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch, int vec_height, int height, int width);
// Tile geometry shared by all kernels in this file.
// BLOCKWIDTH is the thread count per block and the number of unpacked values a
// tile covers. Each BLOCKHEIGHTn is the number of packed int32 rows that hold
// BLOCKWIDTH n-bit values, i.e. BLOCKWIDTH * n / 32.
constexpr int BLOCKWIDTH = 256;
constexpr int BLOCKHEIGHT2 = 16;  // 256 * 2 / 32
constexpr int BLOCKHEIGHT3 = 24;  // 256 * 3 / 32
constexpr int BLOCKHEIGHT4 = 32;  // 256 * 4 / 32
constexpr int BLOCKHEIGHT8 = 64;  // 256 * 8 / 32
// Returns the packed int32 word with the same bit pattern as unsigned, so the
// right shifts used to unpack quantized fields are logical (zero-filling)
// instead of arithmetic.
__device__ inline unsigned int as_unsigned(int i) {
  return static_cast<unsigned int>(i);
}
// Launches VecQuant2MatMulKernel. For every batch row b it atomically adds
// vec[b, :] times dequant(mat) into mul[b, :]. `mat` packs 16 two-bit values
// per int32 along its rows; column w dequantizes as scales[w] * q - zeros[w].
//
// Layout implied by the kernel indexing (row-major, contiguous):
//   vec (batch, vec_height), mat (height, width) int32, mul (batch, width),
//   scales/zeros >= width elements, same floating dtype as vec.
// Results are accumulated with atomicAdd, so `mul` is presumably
// zero-initialised by the caller (NOTE(review): confirm).
void vecquant2matmul_cuda(
    torch::Tensor vec,
    torch::Tensor mat,
    torch::Tensor mul,
    torch::Tensor scales,
    torch::Tensor zeros
) {
  // The kernel indexes raw pointers assuming row-major contiguous storage.
  TORCH_CHECK(vec.is_contiguous() && mat.is_contiguous() && mul.is_contiguous() &&
                  scales.is_contiguous() && zeros.is_contiguous(),
              "vecquant2matmul_cuda: all tensors must be contiguous");

  const int batch = vec.size(0);
  const int vec_height = vec.size(1);
  const int height = mat.size(0);
  const int width = mat.size(1);
  // Nothing to accumulate; launching would use a zero-sized (invalid) grid.
  if (batch == 0 || height == 0 || width == 0) {
    return;
  }

  // Run on the tensors' GPU and on PyTorch's current stream so the kernel is
  // ordered with surrounding PyTorch work.
  const c10::cuda::CUDAGuard device_guard(vec.device());
  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

  const dim3 blocks(
      (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
      (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
      batch);
  const dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
      vec.scalar_type(), "vecquant2matmul_cuda", ([&] {
        VecQuant2MatMulKernel<<<blocks, threads, 0, stream>>>(
            vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
            scales.data_ptr<scalar_t>(), zeros.data_ptr<scalar_t>(),
            batch, vec_height, height, width);
      }));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Dequantize-and-multiply kernel for 2-bit weights (16 values per int32).
// Launch layout (see vecquant2matmul_cuda):
//   blockIdx.x selects a tile of BLOCKHEIGHT2 packed rows (= BLOCKWIDTH values),
//   blockIdx.y a tile of BLOCKWIDTH output columns, blockIdx.z the batch row;
//   blockDim.x must equal BLOCKWIDTH.
// Each thread owns one output column w and atomically adds its partial dot
// product for this tile into mul[b * width + w]. Columns past `width` and
// packed rows past `height` are skipped, and vec entries past `vec_height`
// read as zero, so ragged edge tiles stay in bounds.
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  const int b = blockIdx.z;
  const int h = BLOCKHEIGHT2 * blockIdx.x;
  const int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage the BLOCKWIDTH vec entries this tile multiplies against.
  __shared__ scalar_t blockvec[BLOCKWIDTH];
  const int vec_idx = (h / BLOCKHEIGHT2) * BLOCKWIDTH + threadIdx.x;
  blockvec[threadIdx.x] = vec_idx < vec_height ? vec[b * vec_height + vec_idx] : scalar_t(0);
  __syncthreads();

  // Every thread has passed the only barrier, so surplus columns may leave.
  if (w >= width) {
    return;
  }

  const scalar_t scale = scales[w];
  const scalar_t zero = zeros[w];
  scalar_t res = 0;
  int i = width * h + w;
  for (int r = 0; r < BLOCKHEIGHT2 && h + r < height; ++r) {
    const unsigned int tmp = as_unsigned(mat[i]);
    const int k = r * 16;
#pragma unroll
    for (int j = 0; j < 16; ++j) {
      res += (scale * scalar_t((tmp >> (2 * j)) & 0x3) - zero) * blockvec[k + j];
    }
    i += width;
  }
  atomicAdd(&mul[b * width + w], res);
}
// Launches VecQuant3MatMulKernel. For every batch row b it atomically adds
// vec[b, :] times dequant(mat) into mul[b, :]. `mat` packs 32 three-bit values
// into every 3 consecutive int32 rows; column w dequantizes as
// scales[w] * q - zeros[w].
//
// Layout implied by the kernel indexing (row-major, contiguous):
//   vec (batch, vec_height), mat (height, width) int32, mul (batch, width),
//   scales/zeros >= width elements, same floating dtype as vec.
// Results are accumulated with atomicAdd, so `mul` is presumably
// zero-initialised by the caller (NOTE(review): confirm).
void vecquant3matmul_cuda(
    torch::Tensor vec,
    torch::Tensor mat,
    torch::Tensor mul,
    torch::Tensor scales,
    torch::Tensor zeros
) {
  // The kernel indexes raw pointers assuming row-major contiguous storage.
  TORCH_CHECK(vec.is_contiguous() && mat.is_contiguous() && mul.is_contiguous() &&
                  scales.is_contiguous() && zeros.is_contiguous(),
              "vecquant3matmul_cuda: all tensors must be contiguous");

  const int batch = vec.size(0);
  const int vec_height = vec.size(1);
  const int height = mat.size(0);
  const int width = mat.size(1);
  // Nothing to accumulate; launching would use a zero-sized (invalid) grid.
  if (batch == 0 || height == 0 || width == 0) {
    return;
  }

  // Run on the tensors' GPU and on PyTorch's current stream so the kernel is
  // ordered with surrounding PyTorch work.
  const c10::cuda::CUDAGuard device_guard(vec.device());
  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

  const dim3 blocks(
      (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
      (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
      batch);
  const dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
      vec.scalar_type(), "vecquant3matmul_cuda", ([&] {
        VecQuant3MatMulKernel<<<blocks, threads, 0, stream>>>(
            vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
            scales.data_ptr<scalar_t>(), zeros.data_ptr<scalar_t>(),
            batch, vec_height, height, width);
      }));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Dequantize-and-multiply kernel for 3-bit weights: every 3 consecutive int32
// rows (96 bits) hold 32 three-bit values, two of which straddle word borders.
// Launch layout (see vecquant3matmul_cuda):
//   blockIdx.x selects a tile of BLOCKHEIGHT3 packed rows (= BLOCKWIDTH values),
//   blockIdx.y a tile of BLOCKWIDTH output columns, blockIdx.z the batch row;
//   blockDim.x must equal BLOCKWIDTH.
// Each thread owns one output column w and atomically adds its partial dot
// product for this tile into mul[b * width + w]. Columns past `width` are
// skipped, only complete 3-row groups below `height` are consumed, and vec
// entries past `vec_height` read as zero.
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT3 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage the BLOCKWIDTH vec entries this tile multiplies against.
  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int vec_idx = (h / BLOCKHEIGHT3) * BLOCKWIDTH + threadIdx.x;
  blockvec[threadIdx.x] = vec_idx < vec_height ? vec[b * vec_height + vec_idx] : scalar_t(0);
  __syncthreads();

  // Every thread has passed the only barrier, so surplus columns may leave.
  if (w >= width) {
    return;
  }

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];
  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;
  unsigned int tmp1;
  unsigned int tmp2;
  unsigned int tmp;
  // Each pass consumes one group of three packed rows (32 values).
  for (int r = 0; r + 2 < BLOCKHEIGHT3 && h + r + 2 < height; r += 3) {
    tmp1 = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    // Value 10 of the group: bits 30-31 of word 0 plus bit 0 of word 1.
    tmp2 = as_unsigned(mat[i]);
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
    tmp2 >>= 1;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;
    res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    // Value 21 of the group: bit 31 of word 1 plus bits 0-1 of word 2.
    tmp1 = as_unsigned(mat[i]);
    tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
    tmp1 >>= 2;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;
    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    k += 10;
  }
  atomicAdd(&mul[b * width + w], res);
}
// Launches VecQuant4MatMulKernel. For every batch row b it atomically adds
// vec[b, :] times dequant(mat) into mul[b, :]. `mat` packs 8 four-bit values
// per int32 along its rows; column w dequantizes as scales[w] * q - zeros[w].
//
// Layout implied by the kernel indexing (row-major, contiguous):
//   vec (batch, vec_height), mat (height, width) int32, mul (batch, width),
//   scales/zeros >= width elements, same floating dtype as vec.
// Results are accumulated with atomicAdd, so `mul` is presumably
// zero-initialised by the caller (NOTE(review): confirm).
void vecquant4matmul_cuda(
    torch::Tensor vec,
    torch::Tensor mat,
    torch::Tensor mul,
    torch::Tensor scales,
    torch::Tensor zeros
) {
  // The kernel indexes raw pointers assuming row-major contiguous storage.
  TORCH_CHECK(vec.is_contiguous() && mat.is_contiguous() && mul.is_contiguous() &&
                  scales.is_contiguous() && zeros.is_contiguous(),
              "vecquant4matmul_cuda: all tensors must be contiguous");

  const int batch = vec.size(0);
  const int vec_height = vec.size(1);
  const int height = mat.size(0);
  const int width = mat.size(1);
  // Nothing to accumulate; launching would use a zero-sized (invalid) grid.
  if (batch == 0 || height == 0 || width == 0) {
    return;
  }

  // Run on the tensors' GPU and on PyTorch's current stream so the kernel is
  // ordered with surrounding PyTorch work.
  const c10::cuda::CUDAGuard device_guard(vec.device());
  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

  const dim3 blocks(
      (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
      (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
      batch);
  const dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
      vec.scalar_type(), "vecquant4matmul_cuda", ([&] {
        VecQuant4MatMulKernel<<<blocks, threads, 0, stream>>>(
            vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
            scales.data_ptr<scalar_t>(), zeros.data_ptr<scalar_t>(),
            batch, vec_height, height, width);
      }));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Dequantize-and-multiply kernel for 4-bit weights (8 values per int32).
// Launch layout (see vecquant4matmul_cuda):
//   blockIdx.x selects a tile of BLOCKHEIGHT4 packed rows (= BLOCKWIDTH values),
//   blockIdx.y a tile of BLOCKWIDTH output columns, blockIdx.z the batch row;
//   blockDim.x must equal BLOCKWIDTH.
// Each thread owns one output column w and atomically adds its partial dot
// product for this tile into mul[b * width + w]. Columns past `width` and
// packed rows past `height` are skipped, and vec entries past `vec_height`
// read as zero, so ragged edge tiles stay in bounds.
template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  const int b = blockIdx.z;
  const int h = BLOCKHEIGHT4 * blockIdx.x;
  const int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage the BLOCKWIDTH vec entries this tile multiplies against.
  __shared__ scalar_t blockvec[BLOCKWIDTH];
  const int vec_idx = (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x;
  blockvec[threadIdx.x] = vec_idx < vec_height ? vec[b * vec_height + vec_idx] : scalar_t(0);
  __syncthreads();

  // Every thread has passed the only barrier, so surplus columns may leave.
  if (w >= width) {
    return;
  }

  const scalar_t scale = scales[w];
  const scalar_t zero = zeros[w];
  scalar_t res = 0;
  int i = width * h + w;
  for (int r = 0; r < BLOCKHEIGHT4 && h + r < height; ++r) {
    const unsigned int tmp = as_unsigned(mat[i]);
    const int k = r * 8;
#pragma unroll
    for (int j = 0; j < 8; ++j) {
      res += (scale * scalar_t((tmp >> (4 * j)) & 0xF) - zero) * blockvec[k + j];
    }
    i += width;
  }
  atomicAdd(&mul[b * width + w], res);
}
// Launches VecQuant8MatMulKernel. For every batch row b it atomically adds
// vec[b, :] times dequant(mat) into mul[b, :]. `mat` packs 4 eight-bit values
// per int32 along its rows; column w dequantizes as scales[w] * q - zeros[w].
//
// Layout implied by the kernel indexing (row-major, contiguous):
//   vec (batch, vec_height), mat (height, width) int32, mul (batch, width),
//   scales/zeros >= width elements, same floating dtype as vec.
// Results are accumulated with atomicAdd, so `mul` is presumably
// zero-initialised by the caller (NOTE(review): confirm).
void vecquant8matmul_cuda(
    torch::Tensor vec,
    torch::Tensor mat,
    torch::Tensor mul,
    torch::Tensor scales,
    torch::Tensor zeros
) {
  // The kernel indexes raw pointers assuming row-major contiguous storage.
  TORCH_CHECK(vec.is_contiguous() && mat.is_contiguous() && mul.is_contiguous() &&
                  scales.is_contiguous() && zeros.is_contiguous(),
              "vecquant8matmul_cuda: all tensors must be contiguous");

  const int batch = vec.size(0);
  const int vec_height = vec.size(1);
  const int height = mat.size(0);
  const int width = mat.size(1);
  // Nothing to accumulate; launching would use a zero-sized (invalid) grid.
  if (batch == 0 || height == 0 || width == 0) {
    return;
  }

  // Run on the tensors' GPU and on PyTorch's current stream so the kernel is
  // ordered with surrounding PyTorch work.
  const c10::cuda::CUDAGuard device_guard(vec.device());
  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

  const dim3 blocks(
      (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
      (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
      batch);
  const dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
      vec.scalar_type(), "vecquant8matmul_cuda", ([&] {
        VecQuant8MatMulKernel<<<blocks, threads, 0, stream>>>(
            vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
            scales.data_ptr<scalar_t>(), zeros.data_ptr<scalar_t>(),
            batch, vec_height, height, width);
      }));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Dequantize-and-multiply kernel for 8-bit weights (4 values per int32).
// Launch layout (see vecquant8matmul_cuda):
//   blockIdx.x selects a tile of BLOCKHEIGHT8 packed rows (= BLOCKWIDTH values),
//   blockIdx.y a tile of BLOCKWIDTH output columns, blockIdx.z the batch row;
//   blockDim.x must equal BLOCKWIDTH.
// Each thread owns one output column w and atomically adds its partial dot
// product for this tile into mul[b * width + w]. Columns past `width` and
// packed rows past `height` are skipped, and vec entries past `vec_height`
// read as zero, so ragged edge tiles stay in bounds.
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  const int b = blockIdx.z;
  const int h = BLOCKHEIGHT8 * blockIdx.x;
  const int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage the BLOCKWIDTH vec entries this tile multiplies against.
  __shared__ scalar_t blockvec[BLOCKWIDTH];
  const int vec_idx = (h / BLOCKHEIGHT8) * BLOCKWIDTH + threadIdx.x;
  blockvec[threadIdx.x] = vec_idx < vec_height ? vec[b * vec_height + vec_idx] : scalar_t(0);
  __syncthreads();

  // Every thread has passed the only barrier, so surplus columns may leave.
  if (w >= width) {
    return;
  }

  const scalar_t scale = scales[w];
  const scalar_t zero = zeros[w];
  scalar_t res = 0;
  int i = width * h + w;
  for (int r = 0; r < BLOCKHEIGHT8 && h + r < height; ++r) {
    const unsigned int tmp = as_unsigned(mat[i]);
    const int k = r * 4;
#pragma unroll
    for (int j = 0; j < 4; ++j) {
      res += (scale * scalar_t((tmp >> (8 * j)) & 0xFF) - zero) * blockvec[k + j];
    }
    i += width;
  }
  atomicAdd(&mul[b * width + w], res);
}
// Transposed 4-bit product: out[b, r] += sum_c dequant(mat)[r, c] * vec[b, c],
// where unpacked row r = 8 * h + nibble and column c dequantizes as
// scales[c] * q - zeros[c].
// Launch layout (see vecquant4transposematmul_cuda):
//   blockIdx.x selects BLOCKHEIGHT4 packed rows, blockIdx.y a tile of
//   BLOCKWIDTH columns, blockIdx.z the batch row; blockDim.x == BLOCKWIDTH.
// Eight consecutive threads share one packed row h, each extracting its own
// nibble (threadIdx.x % 8). The result is atomically added into
// mul[b * height * 8 + n_rows]. Rows past `height` and columns past `width`
// are skipped, and vec entries past `vec_height` read as zero.
template <typename scalar_t>
__global__ void VecQuant4TransposeMatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  const int b = blockIdx.z;
  const int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
  const unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
  const int w = BLOCKWIDTH * blockIdx.y;
  const int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
  const int n_cols = b;

  // Stage the BLOCKWIDTH vec entries of this column tile.
  __shared__ scalar_t blockvec[BLOCKWIDTH];
  const int col = w + threadIdx.x;
  blockvec[threadIdx.x] = col < vec_height ? vec[n_cols * vec_height + col] : scalar_t(0);
  __syncthreads();

  // Every thread has passed the only barrier, so out-of-range rows may leave.
  if (h >= height) {
    return;
  }

  scalar_t res = 0;
  const int i = width * h + w;
  for (int k = 0; k < BLOCKWIDTH && w + k < width; ++k) {
    const unsigned int tmp = as_unsigned(mat[i + k]);
    res += (scales[w + k] * scalar_t((tmp >> shift) & 0xF) - zeros[w + k]) * blockvec[k];
  }
  atomicAdd(&mul[n_cols * height * 8 + n_rows], res);
}
// Launches VecQuant4TransposeMatMulKernel. For every batch row b it atomically
// adds dequant(mat) times vec[b, :] into mul[b, :]. `mat` (height, width)
// int32 holds 8 nibbles per word along its rows, so each output row has
// height * 8 entries.
//
// Layout implied by the kernel indexing (row-major, contiguous):
//   vec (batch, vec_height) with vec_height presumably == width (TODO confirm),
//   mul (batch, height * 8), scales/zeros >= width elements, float/double.
// Results are accumulated with atomicAdd, so `mul` is presumably
// zero-initialised by the caller (NOTE(review): confirm).
void vecquant4transposematmul_cuda(
    torch::Tensor vec,
    torch::Tensor mat,
    torch::Tensor mul,
    torch::Tensor scales,
    torch::Tensor zeros
) {
  // The kernel indexes raw pointers assuming row-major contiguous storage.
  TORCH_CHECK(vec.is_contiguous() && mat.is_contiguous() && mul.is_contiguous() &&
                  scales.is_contiguous() && zeros.is_contiguous(),
              "vecquant4transposematmul_cuda: all tensors must be contiguous");

  const int batch = vec.size(0);
  const int vec_height = vec.size(1);
  const int height = mat.size(0);
  const int width = mat.size(1);
  // Nothing to accumulate; launching would use a zero-sized (invalid) grid.
  if (batch == 0 || height == 0 || width == 0) {
    return;
  }

  // Run on the tensors' GPU and on PyTorch's current stream so the kernel is
  // ordered with surrounding PyTorch work.
  const c10::cuda::CUDAGuard device_guard(vec.device());
  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

  const dim3 blocks(
      (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
      (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
      batch);
  const dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
      vec.scalar_type(), "vecquant4transposematmul_cuda", ([&] {
        VecQuant4TransposeMatMulKernel<<<blocks, threads, 0, stream>>>(
            vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
            scales.data_ptr<scalar_t>(), zeros.data_ptr<scalar_t>(),
            batch, vec_height, height, width);
      }));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// fp16 variant of the 4-bit dequantize-and-multiply kernel (8 values per
// int32). All arithmetic, including the running sum, is done in __half.
// Launch layout (see vecquant4matmul_half_cuda):
//   blockIdx.x selects a tile of BLOCKHEIGHT4 packed rows (= BLOCKWIDTH values),
//   blockIdx.y a tile of BLOCKWIDTH output columns, blockIdx.z the batch row;
//   blockDim.x must equal BLOCKWIDTH.
// Each thread owns one output column w and adds its partial dot product into
// mul[b * width + w] via atomicAdd2. Columns past `width` and packed rows past
// `height` are skipped, and vec entries past `vec_height` read as zero.
template <typename scalar_t>
__global__ void VecQuant4MatMulHalfKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  const int b = blockIdx.z;
  const int h = BLOCKHEIGHT4 * blockIdx.x;
  const int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage the BLOCKWIDTH vec entries this tile multiplies against.
  __shared__ __half blockvec[BLOCKWIDTH];
  const int vec_idx = (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x;
  blockvec[threadIdx.x] = vec_idx < vec_height
      ? __half(vec[b * vec_height + vec_idx])
      : __float2half(0.0f);
  __syncthreads();

  // Every thread has passed the only barrier, so surplus columns may leave.
  if (w >= width) {
    return;
  }

  const __half scale = __half(scales[w]);
  const __half zero = __half(zeros[w]);
  __half res = __float2half(0.0f);
  int i = width * h + w;
  for (int r = 0; r < BLOCKHEIGHT4 && h + r < height; ++r) {
    const unsigned int tmp = as_unsigned(mat[i]);
    const int k = r * 8;
#pragma unroll
    for (int j = 0; j < 8; ++j) {
      res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> (4 * j)) & 0xF)), zero), blockvec[k + j]));
    }
    i += width;
  }
  __half* mul2 = (__half*)mul;
  atomicAdd2(&mul2[b * width + w], res);
}
// Launches VecQuant4MatMulHalfKernel for Half tensors only. For every batch row
// b it adds vec[b, :] times dequant(mat) into mul[b, :], where `mat` packs 8
// four-bit values per int32 along its rows and column w dequantizes as
// scales[w] * q - zeros[w].
//
// Layout implied by the kernel indexing (row-major, contiguous):
//   vec (batch, vec_height), mat (height, width) int32, mul (batch, width),
//   scales/zeros >= width elements, all floating tensors Half.
// Results are accumulated atomically, so `mul` is presumably zero-initialised
// by the caller (NOTE(review): confirm).
void vecquant4matmul_half_cuda(
    torch::Tensor vec,
    torch::Tensor mat,
    torch::Tensor mul,
    torch::Tensor scales,
    torch::Tensor zeros
) {
  // The kernel indexes raw pointers assuming row-major contiguous storage.
  TORCH_CHECK(vec.is_contiguous() && mat.is_contiguous() && mul.is_contiguous() &&
                  scales.is_contiguous() && zeros.is_contiguous(),
              "vecquant4matmul_half_cuda: all tensors must be contiguous");

  const int batch = vec.size(0);
  const int vec_height = vec.size(1);
  const int height = mat.size(0);
  const int width = mat.size(1);
  // Nothing to accumulate; launching would use a zero-sized (invalid) grid.
  if (batch == 0 || height == 0 || width == 0) {
    return;
  }

  // Run on the tensors' GPU and on PyTorch's current stream so the kernel is
  // ordered with surrounding PyTorch work.
  const c10::cuda::CUDAGuard device_guard(vec.device());
  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

  const dim3 blocks(
      (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
      (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
      batch);
  const dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_SWITCH(vec.scalar_type(), "vecquant4matmul_half_cuda",
      AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
        VecQuant4MatMulHalfKernel<<<blocks, threads, 0, stream>>>(
            vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
            scales.data_ptr<scalar_t>(), zeros.data_ptr<scalar_t>(),
            batch, vec_height, height, width);
      })));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// fp16 variant of the transposed 4-bit product:
// out[b, r] += sum_c dequant(mat)[r, c] * vec[b, c], with unpacked row
// r = 8 * h + nibble and column c dequantized as scales[c] * q - zeros[c].
// Launch layout (see vecquant4transposematmul_half_cuda):
//   blockIdx.x selects BLOCKHEIGHT4 packed rows, blockIdx.y a tile of
//   BLOCKWIDTH columns, blockIdx.z the batch row; blockDim.x == BLOCKWIDTH.
// Eight consecutive threads share one packed row h, each extracting its own
// nibble (threadIdx.x % 8). Rows past `height` and columns past `width` are
// skipped, and vec entries past `vec_height` read as zero.
template <typename scalar_t>
__global__ void VecQuant4TransposeMatMulHalfKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  const int b = blockIdx.z;
  const int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
  const unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
  const int w = BLOCKWIDTH * blockIdx.y;
  const int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
  const int n_cols = b;

  // Stage the BLOCKWIDTH vec entries of this column tile.
  __shared__ __half blockvec[BLOCKWIDTH];
  const int col = w + threadIdx.x;
  blockvec[threadIdx.x] = col < vec_height
      ? __half(vec[n_cols * vec_height + col])
      : __float2half(0.0f);
  __syncthreads();

  // Every thread has passed the only barrier, so out-of-range rows may leave.
  if (h >= height) {
    return;
  }

  __half res = __float2half(0.0f);
  const int i = width * h + w;
  for (int k = 0; k < BLOCKWIDTH && w + k < width; ++k) {
    const unsigned int tmp = as_unsigned(mat[i + k]);
    const __half zero = __half(zeros[w + k]);
    const __half scale = __half(scales[w + k]);
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> shift) & 0xF)), zero), blockvec[k]));
  }
  __half* mul2 = (__half*)mul;
  atomicAdd2(&mul2[n_cols * height * 8 + n_rows], res);
}
// Launches VecQuant4TransposeMatMulHalfKernel for Half tensors only. For every
// batch row b it adds dequant(mat) times vec[b, :] into mul[b, :]. `mat`
// (height, width) int32 holds 8 nibbles per word along its rows, so each output
// row has height * 8 entries.
//
// Layout implied by the kernel indexing (row-major, contiguous):
//   vec (batch, vec_height) with vec_height presumably == width (TODO confirm),
//   mul (batch, height * 8), scales/zeros >= width elements, all Half.
// Results are accumulated atomically, so `mul` is presumably zero-initialised
// by the caller (NOTE(review): confirm).
void vecquant4transposematmul_half_cuda(
    torch::Tensor vec,
    torch::Tensor mat,
    torch::Tensor mul,
    torch::Tensor scales,
    torch::Tensor zeros
) {
  // The kernel indexes raw pointers assuming row-major contiguous storage.
  TORCH_CHECK(vec.is_contiguous() && mat.is_contiguous() && mul.is_contiguous() &&
                  scales.is_contiguous() && zeros.is_contiguous(),
              "vecquant4transposematmul_half_cuda: all tensors must be contiguous");

  const int batch = vec.size(0);
  const int vec_height = vec.size(1);
  const int height = mat.size(0);
  const int width = mat.size(1);
  // Nothing to accumulate; launching would use a zero-sized (invalid) grid.
  if (batch == 0 || height == 0 || width == 0) {
    return;
  }

  // Run on the tensors' GPU and on PyTorch's current stream so the kernel is
  // ordered with surrounding PyTorch work.
  const c10::cuda::CUDAGuard device_guard(vec.device());
  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

  const dim3 blocks(
      (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
      (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
      batch);
  const dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_SWITCH(vec.scalar_type(), "vecquant4transposematmul_half_cuda",
      AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
        VecQuant4TransposeMatMulHalfKernel<<<blocks, threads, 0, stream>>>(
            vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
            scales.data_ptr<scalar_t>(), zeros.data_ptr<scalar_t>(),
            batch, vec_height, height, width);
      })));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Reconstructs (dequantizes) a 4-bit packed matrix into `res`:
//   res[8 * h' + nibble, w] = scales[w] * q - zeros[w],
// where q is nibble `nibble` of mat[h', w].
// Launch layout (see vecquant4recons_cuda):
//   blockIdx.x selects BLOCKHEIGHT4 packed rows, blockIdx.y a tile of
//   BLOCKWIDTH columns, and blockIdx.z (0..BLOCKWIDTH-1) one of the 256
//   unpacked rows of that tile (packed row offset b / 8, nibble b % 8).
// Threads whose column or packed row falls outside the matrix do nothing.
template <typename scalar_t>
__global__ void VecQuant4ReconsKernel(
    const int* __restrict__ mat,
    scalar_t* __restrict__ res,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int height,
    int width
) {
  const int b = blockIdx.z;
  const int h = BLOCKHEIGHT4 * blockIdx.x;
  const int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
  const int packed_row = h + b / 8;
  // No shared memory or barriers here, so out-of-range threads can exit early.
  if (w >= width || packed_row >= height) {
    return;
  }

  const int n_rows = h * 8 + b;
  const int n_cols = w;
  const scalar_t scale = scales[w];
  const scalar_t zero = zeros[w];
  const int shift = b % 8 * 4;
  const unsigned int tmp = as_unsigned(mat[width * packed_row + w]);
  scalar_t result = (scale * scalar_t((tmp >> shift) & 0xF) - zero);
  res[n_rows * width + n_cols] = result;
}
// Launches VecQuant4ReconsKernel to write the dequantized form of the 4-bit
// packed `mat` (height, width) int32 into `res`, which the kernel indexes as
// (height * 8, width) row-major. scales/zeros hold one value per column and
// share res's dtype (float, double or Half).
void vecquant4recons_cuda(
    torch::Tensor mat,
    torch::Tensor res,
    torch::Tensor scales,
    torch::Tensor zeros
) {
  // The kernel indexes raw pointers assuming row-major contiguous storage.
  TORCH_CHECK(mat.is_contiguous() && res.is_contiguous() &&
                  scales.is_contiguous() && zeros.is_contiguous(),
              "vecquant4recons_cuda: all tensors must be contiguous");

  // gridDim.z enumerates the BLOCKWIDTH unpacked rows of each row tile.
  const int batch = BLOCKWIDTH;
  const int height = mat.size(0);
  const int width = mat.size(1);
  // Nothing to reconstruct; launching would use a zero-sized (invalid) grid.
  if (height == 0 || width == 0) {
    return;
  }

  // Run on the tensors' GPU and on PyTorch's current stream so the kernel is
  // ordered with surrounding PyTorch work.
  const c10::cuda::CUDAGuard device_guard(mat.device());
  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

  const dim3 blocks(
      (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
      (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
      batch);
  const dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      scales.scalar_type(), "vecquant4recons_cuda", ([&] {
        VecQuant4ReconsKernel<<<blocks, threads, 0, stream>>>(
            mat.data_ptr<int>(), res.data_ptr<scalar_t>(),
            scales.data_ptr<scalar_t>(), zeros.data_ptr<scalar_t>(),
            height, width);
      }));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement