SHARE
TWEET

Untitled

a guest Jan 12th, 2018 46 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. /**********************************************************************
  2. Copyright ©2015 Advanced Micro Devices, Inc. All rights reserved.
  3.  
  4. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
  5.  
  6. • Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
  7. • Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or
  8.  other materials provided with the distribution.
  9.  
  10. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  11.  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
  12.  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
  13.  OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  14.  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  15. ********************************************************************/
  16.  
  /* Per-work-item output tile. TILEY is the number of consecutive C rows a
   work-item writes; the *_SHIFT values are log2 of the matching size so row
   offsets can be computed with shifts. TILEX / TILEX_SHIFT are not referenced
   by the kernels in this file (NOTE(review): presumably kept for host code). */
#define TILEX       4
#define TILEY       4
#define TILEX_SHIFT 2   /* log2(TILEX) */
#define TILEY_SHIFT 2   /* log2(TILEY) */
  21.  
  /* mmmKernel: register-tiled matrix multiply C = A * B without local memory. */
/*
 * Output tile size: 4x4. Each work-item computes 4 consecutive rows of C and
 * one T vector (4 scalars) per row, i.e. 16 values.
 * Required global threads = (widthC / 4, heightC / 4).
 * This kernel runs on 7xx and CPU as they don't have hardware local memory.
 *
 * T must be a 4-component vector type (.x/.y/.z/.w are accessed).
 * All matrices are row-major and addressed in T units:
 *   matrixA : widthA / 4 T per row
 *   matrixB : widthB / 4 T per row, NOT transposed, widthA rows
 *   matrixC : widthB / 4 T per row (same pitch as matrixB)
 * widthA    : scalar width of A (= height of B); expected to be a multiple of 4.
 *             Any trailing widthA % 4 columns are ignored rather than read
 *             out of bounds.
 * widthB    : scalar width of B and C; expected to be a multiple of 4.
 */
template <class T>
kernel void mmmKernel(global T * matrixA,
                      global T * matrixB,
                      global T * matrixC,
                      uint widthA,
                      uint widthB)
{
    int2 pos = (int2)(get_global_id(0), get_global_id(1));

    /* Vectorization of the inputs reduces their widths by a factor of 4. */
    const uint pitchA = widthA / 4;
    const uint pitchB = widthB / 4;

    /* First of the TILEY rows of A (and of C) owned by this work-item. */
    const uint row = (uint)pos.y << TILEY_SHIFT;

    T sum0 = (T)(0);
    T sum1 = (T)(0);
    T sum2 = (T)(0);
    T sum3 = (T)(0);

    /* k walks one row of A in T units. Element k covers scalar columns
       4k..4k+3, which pair with rows 4k..4k+3 of B. Bounding by pitchA keeps
       every read inside A and B. */
    for (uint k = 0; k < pitchA; ++k)
    {
        T tempA0 = matrixA[k + (row + 0) * pitchA];
        T tempA1 = matrixA[k + (row + 1) * pitchA];
        T tempA2 = matrixA[k + (row + 2) * pitchA];
        T tempA3 = matrixA[k + (row + 3) * pitchA];

        /* Matrix B is not transposed: 4 consecutive rows at column pos.x. */
        const uint rowB = k * 4;
        T tempB0 = matrixB[pos.x + (rowB + 0) * pitchB];
        T tempB1 = matrixB[pos.x + (rowB + 1) * pitchB];
        T tempB2 = matrixB[pos.x + (rowB + 2) * pitchB];
        T tempB3 = matrixB[pos.x + (rowB + 3) * pitchB];

        /* sumN += (row N of the 4x4 A sub-block) x (4x4 B sub-block).
           Scalar * vector evaluates each component in the same order as an
           explicit .x/.y/.z/.w expansion. */
        sum0 += tempA0.x * tempB0 + tempA0.y * tempB1 + tempA0.z * tempB2 + tempA0.w * tempB3;
        sum1 += tempA1.x * tempB0 + tempA1.y * tempB1 + tempA1.z * tempB2 + tempA1.w * tempB3;
        sum2 += tempA2.x * tempB0 + tempA2.y * tempB1 + tempA2.z * tempB2 + tempA2.w * tempB3;
        sum3 += tempA3.x * tempB0 + tempA3.y * tempB1 + tempA3.z * tempB2 + tempA3.w * tempB3;
    }

    matrixC[pos.x + (row + 0) * pitchB] = sum0;
    matrixC[pos.x + (row + 1) * pitchB] = sum1;
    matrixC[pos.x + (row + 2) * pitchB] = sum2;
    matrixC[pos.x + (row + 3) * pitchB] = sum3;
}
  81.  
  /* Explicit instantiations of mmmKernel for float4 and int4 elements.
   The mangled_name attribute sets the kernel name each one is exported
   under: mmmKernelFloat4 and mmmKernelInt4. */
template __attribute__((mangled_name(mmmKernelFloat4)))
kernel void mmmKernel(global float4 *matrixA,
                      global float4 *matrixB,
                      global float4 *matrixC,
                      uint widthA,
                      uint widthB);

template __attribute__((mangled_name(mmmKernelInt4)))
kernel void mmmKernel(global int4 *matrixA,
                      global int4 *matrixB,
                      global int4 *matrixC,
                      uint widthA,
                      uint widthB);
  94.  
  95.  
  /* mmmKernel_local: matrix multiply C = A * B with blocks of A cached in local memory. */
/*
 * Each work-item computes 4 consecutive rows of C and one T vector (4 scalars)
 * per row. Required global threads = (widthC / 4, heightC / 4).
 *
 * T must be a 4-component vector type (.x/.y/.z/.w are accessed).
 * All matrices are row-major and addressed in T units:
 *   matrixA : widthA / 4 T per row, at least 4 * get_global_size(1) rows
 *   matrixB : get_global_size(0) T per row, NOT transposed, widthA rows
 *   matrixC : get_global_size(0) T per row
 * widthA    : scalar width of A (= height of B); expected to be a multiple of 4
 * blockA    : local scratch of at least
 *             get_local_size(0) * TILEY * get_local_size(1) T elements
 *
 * The K dimension is processed in blocks of get_local_size(0) T columns of A.
 * widthA / 4 does not have to be a multiple of get_local_size(0): the last
 * block may be partial. The block count and block width depend only on widthA
 * and the work-group size, so every work-item in a group reaches the same
 * barriers.
 */
template <class T>
kernel void mmmKernel_local(global T * matrixA,
                            global T * matrixB,
                            global T * matrixC,
                            int widthA,
                            __local T *blockA)
{
    const int localW  = (int)get_local_size(0);
    const int localX  = (int)get_local_id(0);
    const int localY  = (int)get_local_id(1);
    const int globalX = (int)get_global_id(0);
    const int pitchBC = (int)get_global_size(0);                  /* row pitch of B and C, in T */
    const int rowC    = (int)get_global_id(1) << TILEY_SHIFT;     /* first of the 4 rows owned */

    /* Row pitch of A in T units (each T packs 4 scalars). */
    const int pitchA = widthA / 4;

    /* blockA is laid out as localW columns by (TILEY * local rows) rows. This
       work-item fills column localX of rows localY*TILEY .. localY*TILEY+3. */
    const int blockPos = localX + localW * (localY << TILEY_SHIFT);

    /* Offset in blockA of the first of the 4 A rows this work-item consumes. */
    const int tileRowA = (localY << TILEY_SHIFT) * localW;

    /* Each work-item accumulates 4 T vectors. */
    T sum0 = (T)(0);
    T sum1 = (T)(0);
    T sum2 = (T)(0);
    T sum3 = (T)(0);

    /* Ceil division, so that a trailing partial block of A is not dropped. */
    const int numBlocks = (pitchA + localW - 1) / localW;

    for (int i = 0; i < numBlocks; i++)
    {
        const int blockStart = i * localW;
        const int remaining  = pitchA - blockStart;
        const int blockWidth = remaining < localW ? remaining : localW;

        /* Stage 4 rows of A (one per output row) into blockA. Work-items
           beyond the end of a partial block have nothing to load. */
        if (localX < blockWidth)
        {
            const int globalPosA = blockStart + localX + rowC * pitchA;
            blockA[blockPos]              = matrixA[globalPosA];
            blockA[blockPos + localW]     = matrixA[globalPosA + pitchA];
            blockA[blockPos + 2 * localW] = matrixA[globalPosA + 2 * pitchA];
            blockA[blockPos + 3 * localW] = matrixA[globalPosA + 3 * pitchA];
        }

        barrier(CLK_LOCAL_MEM_FENCE);

        /* T column blockStart of A spans scalar columns 4*blockStart.., which
           pair with rows 4*blockStart.. of B. */
        const int globalPosB = globalX + (blockStart * 4) * pitchBC;

        for (int c = 0; c < blockWidth; c++)
        {
            /* Column c of this work-item's 4 A rows (strided local reads). */
            T tempA0 = blockA[c + tileRowA];
            T tempA1 = blockA[c + tileRowA + localW];
            T tempA2 = blockA[c + tileRowA + 2 * localW];
            T tempA3 = blockA[c + tileRowA + 3 * localW];

            /* 4 consecutive rows of B at this work-item's column. */
            const int posB = globalPosB + (c * 4) * pitchBC;
            T tempB0 = matrixB[posB];
            T tempB1 = matrixB[posB + pitchBC];
            T tempB2 = matrixB[posB + 2 * pitchBC];
            T tempB3 = matrixB[posB + 3 * pitchBC];

            /* Scalar * vector evaluates each component in the same order as
               an explicit .x/.y/.z/.w expansion. */
            sum0 += tempA0.x * tempB0 + tempA0.y * tempB1 + tempA0.z * tempB2 + tempA0.w * tempB3;
            sum1 += tempA1.x * tempB0 + tempA1.y * tempB1 + tempA1.z * tempB2 + tempA1.w * tempB3;
            sum2 += tempA2.x * tempB0 + tempA2.y * tempB1 + tempA2.z * tempB2 + tempA2.w * tempB3;
            sum3 += tempA3.x * tempB0 + tempA3.y * tempB1 + tempA3.z * tempB2 + tempA3.w * tempB3;
        }

        /* Keep blockA intact until every work-item has finished reading it. */
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    /* Write the 4 accumulated rows (16 values) to matrixC. */
    const int globalPos = globalX + rowC * pitchBC;
    matrixC[globalPos]               = sum0;
    matrixC[globalPos + pitchBC]     = sum1;
    matrixC[globalPos + 2 * pitchBC] = sum2;
    matrixC[globalPos + 3 * pitchBC] = sum3;
}
  181.  
  /* Explicit instantiations of mmmKernel_local for float4 and int4 elements.
   The mangled_name attribute sets the kernel name each one is exported
   under: mmmKernel_localFloat4 and mmmKernel_localInt4. */
template __attribute__((mangled_name(mmmKernel_localFloat4)))
kernel void mmmKernel_local(global float4 *matrixA,
                            global float4 *matrixB,
                            global float4 *matrixC,
                            int widthA,
                            __local float4 *blockA);

template __attribute__((mangled_name(mmmKernel_localInt4)))
kernel void mmmKernel_local(global int4 *matrixA,
                            global int4 *matrixB,
                            global int4 *matrixC,
                            int widthA,
                            __local int4 *blockA);
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top