Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- // Macierze są pamiętane wierszami, a więc:
- // M(row, col) = *(M.elements + row * M.width + col)
- typedef struct {
- int width;
- int height;
- float *elements;
- } Matrix;
- // definiujemy rozmiar bloku wątków:
- #define BLOCK_SIZE 16
- // prototyp funkcji mnożącej (kernela)
- __global__ void MatMulKernel(const Matrix, const Matrix, Matrix);
- // Zakładamy (dla uproszczenia rozważań), że wymiary macierzy są
- // całkowitymi wielokrotnościami wartości BLOCK_SIZE
- // Funkcja mnożąca
- void MatMul(const Matrix A, const Matrix B, Matrix C)
- {
- // kopiujemy macierze A i B to globalnej pamięci urządzenia
- // najpierw A
- Matrix d_A;
- d_A.width = A.width;
- d_A.height = A.height;
- size_t size = A.width * A.height * sizeof(float);
- cudaMalloc((void **)&d_A.elements, size);
- cudaMemcpy(d_A.elements, A.elements, size, cudaMemcpyHostToDevice);
- // potem B
- Matrix d_B;
- d_B.width = B.width;
- d_B.height = B.height;
- size = B.width * B.height * sizeof(float);
- cudaMalloc((void **)&d_B.elements, size);
- cudaMemcpy(d_B.elements, B.elements, size,cudaMemcpyHostToDevice);
- // przydzielamy macierz C w globalnej pamięci urządzenia
- Matrix d_C;
- d_C.width = C.width;
- d_C.height = C.height;
- size = C.width * C.height * sizeof(float);
- cudaMalloc((void **)&d_C.elements, size);
- // preparujemy środowisko i wywołujemy kernel
- dim3 dimBlock(BLOCK_SIZE, BLOCK_SIZE);
- dim3 dimGrid(B.width / dimBlock.x, A.height / dimBlock.y);
- MatMulKernel<<<dimGrid, dimBlock>>>(d_A, d_B, d_C);
- // odbieramy obliczoną macierz C z pamięci globalnej urządzenia
- cudaMemcpy(C.elements, d_C.elements, size, cudaMemcpyDeviceToHost);
- // zwalniamy pamięć
- cudaFree(d_A.elements);
- cudaFree(d_B.elements);
- cudaFree(d_C.elements);
- }
- // kernel odpowiedzialny za wymnożenie macierzy
- __global__ void MatMulKernel(Matrix A, Matrix B, Matrix C)
- {
- // każdy wątek oblicza jeden element macierzy C
- // akumulując wynik w zmiennej Cvalue
- float Cvalue = 0;
- int row = blockIdx.y * blockDim.y + threadIdx.y;
- int col = blockIdx.x * blockDim.x + threadIdx.x;
- for (int e = 0; e < A.width; ++e)
- Cvalue += A.elements[row * A.width + e]
- * B.elements[e * B.width + col];
- C.elements[row * C.width + col] = Cvalue;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement