Advertisement
Guest User

Untitled

a guest
May 24th, 2015
202
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 8.10 KB | None | 0 0
  /*
   * Tile geometry handled by one work-item.
   * The *_SHIFT constants are the log2 of the tile extents; the extents
   * themselves are derived from the shifts so the two stay in agreement.
   */
  #define TILEX_SHIFT 2
  #define TILEX (1 << TILEX_SHIFT)
  #define TILEY_SHIFT 2
  #define TILEY (1 << TILEY_SHIFT)
  5.  
  6.  
  7. __kernel void mmmKernel(__global float4 *matrixA,
  8. __global float4 *matrixB,
  9. __global float4* matrixC,
  10. uint widthA, uint widthB)
  11. {
  12. int2 pos = (int2)(get_global_id(0), get_global_id(1));
  13.  
  14.  
  15. float4 sum0 = (float4)(0);
  16. float4 sum1 = (float4)(0);
  17. float4 sum2 = (float4)(0);
  18. float4 sum3 = (float4)(0);
  19.  
  20. /* Vectorization of input Matrices reduces their width by a factor of 4 */
  21. widthB /= 4;
  22.  
  23. for(int i = 0; i < widthA; i=i+4)
  24. {
  25. float4 tempA0 = matrixA[i/4 + (pos.y << TILEY_SHIFT) * (widthA / 4)];
  26. float4 tempA1 = matrixA[i/4 + ((pos.y << TILEY_SHIFT) + 1) * (widthA / 4)];
  27. float4 tempA2 = matrixA[i/4 + ((pos.y << TILEY_SHIFT) + 2) * (widthA / 4)];
  28. float4 tempA3 = matrixA[i/4 + ((pos.y << TILEY_SHIFT) + 3) * (widthA / 4)];
  29.  
  30. //Matrix B is not transposed
  31. float4 tempB0 = matrixB[pos.x + i * widthB];
  32. float4 tempB1 = matrixB[pos.x + (i + 1) * widthB];
  33. float4 tempB2 = matrixB[pos.x + (i + 2) * widthB];
  34. float4 tempB3 = matrixB[pos.x + (i + 3) * widthB];
  35.  
  36. sum0.x += tempA0.x * tempB0.x + tempA0.y * tempB1.x + tempA0.z * tempB2.x + tempA0.w * tempB3.x;
  37. sum0.y += tempA0.x * tempB0.y + tempA0.y * tempB1.y + tempA0.z * tempB2.y + tempA0.w * tempB3.y;
  38. sum0.z += tempA0.x * tempB0.z + tempA0.y * tempB1.z + tempA0.z * tempB2.z + tempA0.w * tempB3.z;
  39. sum0.w += tempA0.x * tempB0.w + tempA0.y * tempB1.w + tempA0.z * tempB2.w + tempA0.w * tempB3.w;
  40.  
  41. sum1.x += tempA1.x * tempB0.x + tempA1.y * tempB1.x + tempA1.z * tempB2.x + tempA1.w * tempB3.x;
  42. sum1.y += tempA1.x * tempB0.y + tempA1.y * tempB1.y + tempA1.z * tempB2.y + tempA1.w * tempB3.y;
  43. sum1.z += tempA1.x * tempB0.z + tempA1.y * tempB1.z + tempA1.z * tempB2.z + tempA1.w * tempB3.z;
  44. sum1.w += tempA1.x * tempB0.w + tempA1.y * tempB1.w + tempA1.z * tempB2.w + tempA1.w * tempB3.w;
  45.  
  46. sum2.x += tempA2.x * tempB0.x + tempA2.y * tempB1.x + tempA2.z * tempB2.x + tempA2.w * tempB3.x;
  47. sum2.y += tempA2.x * tempB0.y + tempA2.y * tempB1.y + tempA2.z * tempB2.y + tempA2.w * tempB3.y;
  48. sum2.z += tempA2.x * tempB0.z + tempA2.y * tempB1.z + tempA2.z * tempB2.z + tempA2.w * tempB3.z;
  49. sum2.w += tempA2.x * tempB0.w + tempA2.y * tempB1.w + tempA2.z * tempB2.w + tempA2.w * tempB3.w;
  50.  
  51. sum3.x += tempA3.x * tempB0.x + tempA3.y * tempB1.x + tempA3.z * tempB2.x + tempA3.w * tempB3.x;
  52. sum3.y += tempA3.x * tempB0.y + tempA3.y * tempB1.y + tempA3.z * tempB2.y + tempA3.w * tempB3.y;
  53. sum3.z += tempA3.x * tempB0.z + tempA3.y * tempB1.z + tempA3.z * tempB2.z + tempA3.w * tempB3.z;
  54. sum3.w += tempA3.x * tempB0.w + tempA3.y * tempB1.w + tempA3.z * tempB2.w + tempA3.w * tempB3.w;
  55. }
  56. matrixC[pos.x + ((pos.y << TILEY_SHIFT) + 0) * widthB] = sum0;
  57. matrixC[pos.x + ((pos.y << TILEY_SHIFT) + 1) * widthB] = sum1;
  58. matrixC[pos.x + ((pos.y << TILEY_SHIFT) + 2) * widthB] = sum2;
  59. matrixC[pos.x + ((pos.y << TILEY_SHIFT) + 3) * widthB] = sum3;
  60. }
  61.  
  62.  
  63. /* Matrix A is cached into local memory block */
  64. /* Required global threads = (widthC / 4, heightC / 4) */
  65. __kernel void mmmKernel_local(__global float4 *matrixA,
  66. __global float4 *matrixB,
  67. __global float4* matrixC,
  68. int widthA,
  69. __local float4 *blockA)
  70. {
  71. int blockPos = get_local_id(0) + get_local_size(0) * (get_local_id(1) << TILEY_SHIFT); //Should be : localId * (TILEX / 4) (float4)
  72.  
  73. /* Position of thread will be according to the number of values it writes i.e TILE size */
  74. int globalPos = get_global_id(0) + (get_global_id(1) << TILEY_SHIFT) * get_global_size(0);
  75.  
  76. /* Each thread writes 4 float4s */
  77. float4 sum0 = (float4)(0);
  78. float4 sum1 = (float4)(0);
  79. float4 sum2 = (float4)(0);
  80. float4 sum3 = (float4)(0);
  81.  
  82. int temp = widthA / 4;
  83.  
  84. /* This loop runs for number of blocks of A in horizontal direction */
  85. for(int i = 0; i < (temp / get_local_size(0)); i++)
  86. {
  87. /* Calculate global ids of threads from the particular block to load from matrix A depending on i */
  88. int globalPosA = i * get_local_size(0) + get_local_id(0) + (get_global_id(1) << TILEY_SHIFT) * temp;
  89.  
  90. /* Load values in blockA from matrixA */
  91. blockA[blockPos] = matrixA[globalPosA];
  92. blockA[blockPos + get_local_size(0)] = matrixA[globalPosA + temp];
  93. blockA[blockPos + 2 * get_local_size(0)] = matrixA[globalPosA + 2 * temp];
  94. blockA[blockPos + 3 * get_local_size(0)] = matrixA[globalPosA + 3 * temp];
  95.  
  96. barrier(CLK_LOCAL_MEM_FENCE);
  97.  
  98. /* Calculate global ids of threads from the particular block to load from matrix B depending on i */
  99. int globalPosB = get_global_id(0) + ((i * get_local_size(0)) << TILEY_SHIFT) * get_global_size(0);
  100.  
  101. /* This loop runs for number of threads in horizontal direction in the block of A */
  102. for(int j = 0; j < get_local_size(0) * 4; j=j+4)
  103. {
  104. /* Load 4 float4s from blockA : access patters = strided from local memory */
  105. float4 tempA0 = blockA[(j >> 2) + get_local_id(1) * TILEY * get_local_size(0)];
  106. float4 tempA1 = blockA[(j >> 2) + (get_local_id(1) * TILEY + 1) * get_local_size(0)];
  107. float4 tempA2 = blockA[(j >> 2) + (get_local_id(1) * TILEY + 2) * get_local_size(0)];
  108. float4 tempA3 = blockA[(j >> 2) + (get_local_id(1) * TILEY + 3) * get_local_size(0)];
  109.  
  110. /* Load corresponding values from matrixB, access pattern = linear from global memory */
  111. float4 tempB0 = matrixB[globalPosB + j * get_global_size(0)]; //Should be localId.x * (TILEX / 4)
  112. float4 tempB1 = matrixB[globalPosB + (j + 1) * get_global_size(0)];
  113. float4 tempB2 = matrixB[globalPosB + (j + 2) * get_global_size(0)];
  114. float4 tempB3 = matrixB[globalPosB + (j + 3) * get_global_size(0)];
  115.  
  116. sum0.x += tempA0.x * tempB0.x + tempA0.y * tempB1.x + tempA0.z * tempB2.x + tempA0.w * tempB3.x;
  117. sum0.y += tempA0.x * tempB0.y + tempA0.y * tempB1.y + tempA0.z * tempB2.y + tempA0.w * tempB3.y;
  118. sum0.z += tempA0.x * tempB0.z + tempA0.y * tempB1.z + tempA0.z * tempB2.z + tempA0.w * tempB3.z;
  119. sum0.w += tempA0.x * tempB0.w + tempA0.y * tempB1.w + tempA0.z * tempB2.w + tempA0.w * tempB3.w;
  120.  
  121. sum1.x += tempA1.x * tempB0.x + tempA1.y * tempB1.x + tempA1.z * tempB2.x + tempA1.w * tempB3.x;
  122. sum1.y += tempA1.x * tempB0.y + tempA1.y * tempB1.y + tempA1.z * tempB2.y + tempA1.w * tempB3.y;
  123. sum1.z += tempA1.x * tempB0.z + tempA1.y * tempB1.z + tempA1.z * tempB2.z + tempA1.w * tempB3.z;
  124. sum1.w += tempA1.x * tempB0.w + tempA1.y * tempB1.w + tempA1.z * tempB2.w + tempA1.w * tempB3.w;
  125.  
  126. sum2.x += tempA2.x * tempB0.x + tempA2.y * tempB1.x + tempA2.z * tempB2.x + tempA2.w * tempB3.x;
  127. sum2.y += tempA2.x * tempB0.y + tempA2.y * tempB1.y + tempA2.z * tempB2.y + tempA2.w * tempB3.y;
  128. sum2.z += tempA2.x * tempB0.z + tempA2.y * tempB1.z + tempA2.z * tempB2.z + tempA2.w * tempB3.z;
  129. sum2.w += tempA2.x * tempB0.w + tempA2.y * tempB1.w + tempA2.z * tempB2.w + tempA2.w * tempB3.w;
  130.  
  131. sum3.x += tempA3.x * tempB0.x + tempA3.y * tempB1.x + tempA3.z * tempB2.x + tempA3.w * tempB3.x;
  132. sum3.y += tempA3.x * tempB0.y + tempA3.y * tempB1.y + tempA3.z * tempB2.y + tempA3.w * tempB3.y;
  133. sum3.z += tempA3.x * tempB0.z + tempA3.y * tempB1.z + tempA3.z * tempB2.z + tempA3.w * tempB3.z;
  134. sum3.w += tempA3.x * tempB0.w + tempA3.y * tempB1.w + tempA3.z * tempB2.w + tempA3.w * tempB3.w;
  135.  
  136. }
  137. barrier(CLK_LOCAL_MEM_FENCE);
  138. }
  139. /* Write 16 values to matrixC */
  140. matrixC[globalPos] = sum0;
  141. matrixC[globalPos + get_global_size(0)] = sum1;
  142. matrixC[globalPos + 2 * get_global_size(0)] = sum2;
  143. matrixC[globalPos + 3 * get_global_size(0)] = sum3;
  144.  
  145. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement