Advertisement
Guest User

Untitled

a guest
Dec 29th, 2012
59
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 8.26 KB | None | 0 0
  1. /*
  2. Copyright 2011 Andreas Kloeckner
  3. Copyright 2008-2011 NVIDIA Corporation
  4. Licensed under the Apache License, Version 2.0 (the "License");
  5. you may not use this file except in compliance with the License.
  6. You may obtain a copy of the License at
  7.  
  8.     http://www.apache.org/licenses/LICENSE-2.0
  9.  
  10. Unless required by applicable law or agreed to in writing, software
  11. distributed under the License is distributed on an "AS IS" BASIS,
  12. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. See the License for the specific language governing permissions and
  14. limitations under the License.
  15.  
  16. Derived from thrust/detail/backend/cuda/detail/fast_scan.inl
  17. within the Thrust project, https://code.google.com/p/thrust/
  18.  
  19. Direct browse link:
  20. https://code.google.com/p/thrust/source/browse/thrust/detail/backend/cuda/detail/fast_scan.inl
  21. */
  22. #define SWG_SIZE 128
  23. #define K 8
  24.  
  25. #define UWG_SIZE 256
  26.  
  27. #define REQD_WG_SIZE(X,Y,Z) __attribute__((reqd_work_group_size(X, Y, Z)))
  28.  
  29. #define SCAN_EXPR(a, b) a+b
  30. typedef uint scan_type;
  31.  
  32. __kernel
  33. REQD_WG_SIZE(UWG_SIZE, 1, 1)
  34. void scan_final_update(
  35.   __global scan_type *output,
  36.   const uint N,
  37.   const uint interval_size,
  38.   __global scan_type *group_results)
  39. {
  40.   const uint interval_begin = interval_size * get_group_id(0);
  41.   const uint interval_end   = min(interval_begin + interval_size, N);
  42.  
  43.   if (get_group_id(0) == 0)
  44.     return;
  45.  
  46.   // value to add to this segment
  47.   scan_type prev_group_sum = group_results[get_group_id(0) - 1];
  48.  
  49.   // advance result pointer
  50.   output += interval_begin + get_local_id(0);
  51.  
  52.   for(uint unit_base = interval_begin;
  53.       unit_base < interval_end;
  54.       unit_base += UWG_SIZE, output += UWG_SIZE)
  55.   {
  56.     const uint i = unit_base + get_local_id(0);
  57.  
  58.     if(i < interval_end) {
  59.       *output = SCAN_EXPR(prev_group_sum, *output);
  60.     }
  61.   }
  62. }
  63.  
  64. void scan_group(__local scan_type *array)
  65. {
  66.   scan_type val = array[get_local_id(0)];
  67.  
  68.   for (uint offset=1; offset <= SWG_SIZE; offset *= 2) {
  69.     if (get_local_id(0) >= offset) {
  70.       scan_type tmp = array[get_local_id(0) - offset];
  71.       val = SCAN_EXPR(tmp, val);
  72.     }
  73.  
  74.     barrier(CLK_LOCAL_MEM_FENCE);
  75.     array[get_local_id(0)] = val;
  76.     barrier(CLK_LOCAL_MEM_FENCE);
  77.   }
  78. }
  79.  
  80. void scan_group_n(__local scan_type *array, const uint n)
  81. {
  82.   scan_type val = array[get_local_id(0)];
  83.  
  84.   for (uint offset=1; offset <= SWG_SIZE; offset *= 2) {
  85.     if (get_local_id(0) >= offset && get_local_id(0) < n) {
  86.       scan_type tmp = array[get_local_id(0) - offset];
  87.       val = SCAN_EXPR(tmp, val);
  88.     }
  89.  
  90.     barrier(CLK_LOCAL_MEM_FENCE);
  91.     array[get_local_id(0)] = val;
  92.     barrier(CLK_LOCAL_MEM_FENCE);
  93.   }
  94. }
  95.  
  96. __kernel
  97. REQD_WG_SIZE(SWG_SIZE, 1, 1)
  98. void scan_scan_intervals(
  99.   __global scan_type *input,
  100.   const uint N,
  101.   const uint interval_size,
  102.   __global scan_type *output,
  103.   __global scan_type *group_results)
  104. {
  105.   // padded in WG_SIZE to avoid bank conflicts
  106.   // index K in first dimension used for carry storage
  107.   __local scan_type ldata[K + 1][SWG_SIZE + 1];
  108.  
  109.   const uint interval_begin = interval_size * get_group_id(0);
  110.   const uint interval_end   = min(interval_begin + interval_size, N);
  111.  
  112.   const uint unit_size  = K * SWG_SIZE;
  113.  
  114.   uint unit_base = interval_begin;
  115.  
  116.  
  117.   for(; unit_base + unit_size <= interval_end; unit_base += unit_size) {
  118.     // Algorithm: Each work group is responsible for one contiguous
  119.     // 'interval', of which there are just enough to fill all compute
  120.     // units.  Intervals are split into 'units'. A unit is what gets
  121.     // worked on in parallel by one work group.
  122.  
  123.     // Each unit has two axes--the local-id axis and the k axis.
  124.     //
  125.     // * * * * * * * * * * ----> lid
  126.     // * * * * * * * * * *
  127.     // * * * * * * * * * *
  128.     // * * * * * * * * * *
  129.     // * * * * * * * * * *
  130.     // |
  131.     // v k
  132.  
  133.     // This is a three-phase algorithm, in which first each interval
  134.     // does its local scan, then a scan across intervals exchanges data
  135.     // globally, and the final update adds the exchanged sums to each
  136.     // interval.
  137.  
  138.     // Exclusive scan is realized by performing a right-shift inside
  139.     // the final update.
  140.  
  141.     // read a unit's worth of data from global
  142.  
  143.     for(uint k = 0; k < K; k++) {
  144.       const uint offset = k*SWG_SIZE + get_local_id(0);
  145.  
  146.       ldata[offset % K][offset / K] = input[unit_base + offset];
  147.     }
  148.  
  149.     // carry in from previous unit, if applicable.
  150.     if (get_local_id(0) == 0 && unit_base != interval_begin)
  151.       ldata[0][0] = SCAN_EXPR(ldata[K][SWG_SIZE - 1], ldata[0][0]);
  152.  
  153.     barrier(CLK_LOCAL_MEM_FENCE);
  154.  
  155.     // scan along k (sequentially in each work item)
  156.     scan_type sum = ldata[0][get_local_id(0)];
  157.  
  158.     for(uint k = 1; k < K; k++) {
  159.       scan_type tmp = ldata[k][get_local_id(0)];
  160.       sum = SCAN_EXPR(sum, tmp);
  161.       ldata[k][get_local_id(0)] = sum;
  162.     }
  163.  
  164.     // store carry in out-of-bounds (padding) array entry in the K direction
  165.     ldata[K][get_local_id(0)] = sum;
  166.     barrier(CLK_LOCAL_MEM_FENCE);
  167.  
  168.     // tree-based parallel scan along local id
  169.     scan_group(&ldata[K][0]);
  170.  
  171.     // update local values
  172.     if (get_local_id(0) > 0) {
  173.       sum = ldata[K][get_local_id(0) - 1];
  174.  
  175.       for(uint k = 0; k < K; k++) {
  176.       scan_type tmp = ldata[k][get_local_id(0)];
  177.       ldata[k][get_local_id(0)] = SCAN_EXPR(sum, tmp);
  178.       }
  179.     }
  180.  
  181.     barrier(CLK_LOCAL_MEM_FENCE);
  182.  
  183.     // write data
  184.     for(uint k = 0; k < K; k++) {
  185.       const uint offset = k*SWG_SIZE + get_local_id(0);
  186.  
  187.       output[unit_base + offset] = ldata[offset % K][offset / K];
  188.     }
  189.  
  190.     barrier(CLK_LOCAL_MEM_FENCE);
  191.   }
  192.  
  193.  
  194.   if (unit_base < interval_end) {
  195.     // Algorithm: Each work group is responsible for one contiguous
  196.     // 'interval', of which there are just enough to fill all compute
  197.     // units.  Intervals are split into 'units'. A unit is what gets
  198.     // worked on in parallel by one work group.
  199.  
  200.     // Each unit has two axes--the local-id axis and the k axis.
  201.     //
  202.     // * * * * * * * * * * ----> lid
  203.     // * * * * * * * * * *
  204.     // * * * * * * * * * *
  205.     // * * * * * * * * * *
  206.     // * * * * * * * * * *
  207.     // |
  208.     // v k
  209.  
  210.     // This is a three-phase algorithm, in which first each interval
  211.     // does its local scan, then a scan across intervals exchanges data
  212.     // globally, and the final update adds the exchanged sums to each
  213.     // interval.
  214.  
  215.     // Exclusive scan is realized by performing a right-shift inside
  216.     // the final update.
  217.  
  218.     // read a unit's worth of data from global
  219.  
  220.     for(uint k = 0; k < K; k++) {
  221.       const uint offset = k*SWG_SIZE + get_local_id(0);
  222.  
  223.       if (unit_base + offset < interval_end) {
  224.     ldata[offset % K][offset / K] = input[unit_base + offset];
  225.       }
  226.     }
  227.  
  228.     // carry in from previous unit, if applicable.
  229.     if (get_local_id(0) == 0 && unit_base != interval_begin)
  230.       ldata[0][0] = SCAN_EXPR(ldata[K][SWG_SIZE - 1], ldata[0][0]);
  231.  
  232.     barrier(CLK_LOCAL_MEM_FENCE);
  233.  
  234.     // scan along k (sequentially in each work item)
  235.     scan_type sum = ldata[0][get_local_id(0)];
  236.  
  237.     const uint offset_end = interval_end - unit_base;
  238.  
  239.     for(uint k = 1; k < K; k++) {
  240.       if (K * get_local_id(0) + k < offset_end) {
  241.     scan_type tmp = ldata[k][get_local_id(0)];
  242.     sum = SCAN_EXPR(sum, tmp);
  243.     ldata[k][get_local_id(0)] = sum;
  244.       }
  245.     }
  246.  
  247.     // store carry in out-of-bounds (padding) array entry in the K direction
  248.     ldata[K][get_local_id(0)] = sum;
  249.     barrier(CLK_LOCAL_MEM_FENCE);
  250.  
  251.     // tree-based parallel scan along local id
  252.     scan_group_n(&ldata[K][0], offset_end / K);
  253.  
  254.     // update local values
  255.     if (get_local_id(0) > 0) {
  256.       sum = ldata[K][get_local_id(0) - 1];
  257.  
  258.       for(uint k = 0; k < K; k++) {
  259.     if (K * get_local_id(0) + k < offset_end) {
  260.       scan_type tmp = ldata[k][get_local_id(0)];
  261.       ldata[k][get_local_id(0)] = SCAN_EXPR(sum, tmp);
  262.     }
  263.       }
  264.     }
  265.  
  266.     barrier(CLK_LOCAL_MEM_FENCE);
  267.  
  268.     // write data
  269.     for(uint k = 0; k < K; k++) {
  270.       const uint offset = k*SWG_SIZE + get_local_id(0);
  271.  
  272.       if (unit_base + offset < interval_end) {
  273.     output[unit_base + offset] = ldata[offset % K][offset / K];
  274.       }
  275.     }
  276.  
  277.     barrier(CLK_LOCAL_MEM_FENCE);
  278.   }
  279.  
  280.  
  281.   // write interval sum
  282.   if (get_local_id(0) == 0) {
  283.     group_results[get_group_id(0)] = output[interval_end - 1];
  284.   }
  285. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement