Advertisement
Guest User

Untitled

a guest
Jul 30th, 2014
240
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.07 KB | None | 0 0
  // matrixMultiply -- despite its name, this kernel does NOT compute A*B.
  // For every row r of A and every column c of B it writes the squared
  // Euclidean distance
  //     C[r][c] = sum_k (A[r][k] - B[k][c])^2
  // A is hA x wA, B is wA x wB, C is hA x wB; all are dense row-major.
  //
  // Launch requirements:
  //   * blockDim must be exactly (32, 32, 1) -- the tile size is hard-coded.
  //   * gridDim.x >= ceil(wB / 32), gridDim.y >= ceil(hA / 32).
  //   * Uses 2 * 32 * 32 floats (8 KB) of static shared memory.
  //   * Indices are computed in int, so hA*wA, wA*wB and hA*wB must fit in int.
  //
  // Out-of-range tile elements are stored as 0 in BOTH shared tiles, so the
  // padded positions contribute (0 - 0)^2 = 0 and wA need not be a multiple
  // of 32.
  __global__ static
  void matrixMultiply(int wA,                       // columns of A == rows of B
                      int hA,                       // rows of A == rows of C
                      int wB,                       // columns of B == columns of C
                      const float *objects,         // matrix A (input)
                      const float *deviceClusters,  // matrix B (input)
                      float *deviceC)               // matrix C (output)
  {
      const int TILE = 32;

      // Row-wise stores, broadcast reads of AS[ty][i] and consecutive reads of
      // BS[i][tx] are all bank-conflict free, so no padding column is needed.
      __shared__ float AS[TILE][TILE];
      __shared__ float BS[TILE][TILE];

      const int tx  = threadIdx.x;
      const int ty  = threadIdx.y;
      const int row = blockIdx.y * TILE + ty;   // row of A and of C
      const int col = blockIdx.x * TILE + tx;   // column of B and of C

      float cSub = 0.0f;

      // numTiles is uniform across the block, so every thread reaches every
      // __syncthreads() below.
      const int numTiles = (wA + TILE - 1) / TILE;
      for (int t = 0; t < numTiles; ++t)
      {
          const int aCol = t * TILE + tx;       // column of A loaded by this thread
          const int bRow = t * TILE + ty;       // row of B loaded by this thread

          // Write either the element or an explicit 0 every iteration. This
          // both bounds-checks the B row (bRow < wA, not just the tile start)
          // and removes the racy "reset to 0 after use" step.
          AS[ty][tx] = (row < hA && aCol < wA) ? objects[row * wA + aCol] : 0.0f;
          BS[ty][tx] = (bRow < wA && col < wB) ? deviceClusters[bRow * wB + col] : 0.0f;

          __syncthreads();  // both tiles fully written before anyone reads them

          #pragma unroll
          for (int i = 0; i < TILE; ++i)
          {
              const float d = AS[ty][i] - BS[i][tx];
              cSub += d * d;
          }

          __syncthreads();  // all reads finished before the next iteration overwrites the tiles
      }

      if (row < hA && col < wB)
          deviceC[row * wB + col] = cSub;
  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement