Guest User

用显卡(cuda)计算扫雷高级局面的3BV

a guest
Sep 23rd, 2022
104
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 16.40 KB | Source Code | 0 0
  1.  
  2. #include <stdio.h>
  3. #include <stdlib.h>
  4. #include <cstring>
  5. #include <time.h>
  6. #include "cuda_runtime.h"
  7. #include "device_launch_parameters.h"
  8. #include <curand_kernel.h>
  9. #include <curand.h>
  10. #include <stack>
  11. // #define GRIDSIZE 2048;
  12. // #define BLOCKSIZE 128;
  13. // nvcc main.cu --run
  14. // Connected-Component-Labelling
  15.  
  16. // 同时计算的局面数,即线程块的数量
  17. const int N = 1024;
  18. // -1:雷
  19. // -10:1~8
  20. // 0~60:空及其标记
  21.  
  22. __device__ unsigned int curand_int(unsigned int n, curandStateMRG32k3a_t* state) {
  23.     // 0到n-1之间的随机数,闭区间
  24.     if (n == 0) {
  25.         return 0;
  26.     }
  27.     unsigned int t = curand(state);
  28.     while (t > (0xffffffff / n * n)) {
  29.         t = curand(state);
  30.     }
  31.     return t % n;
  32.     return t;
  33. }
  34.  
  35.  
  36.  
  37. __global__ void init_all_curand(curandStateMRG32k3a_t *states) {
  38.     // 对所有随机数种子初始化
  39.     unsigned long long tid = threadIdx.x + blockIdx.x * 99;
  40.     unsigned long long subsequence = threadIdx.x;
  41.     unsigned long long offset = 0;
  42.     curand_init(tid, subsequence, offset, &states[tid]);
  43.     // printf("%3d, ", tid);
  44.     // printf("%u, ", curand(&states[tid]));
  45. }
  46.  
  47.  
  48.  
  49. __device__ void cal_num_for_cell(int *cuda_board, unsigned int tid, unsigned int x, unsigned int y) {
  50.     if (cuda_board[tid] == -1) {
  51.         return;
  52.     }
  53.     if (x == 0) {
  54.         // 正下
  55.         if (cuda_board[tid + 1] == -1) {
  56.             cuda_board[tid] = -10;
  57.             return;
  58.         }
  59.         if (y == 29) {
  60.             // 正左、左下角
  61.             if (cuda_board[tid - 16] == -1 || cuda_board[tid - 15] == -1) {
  62.                 cuda_board[tid] = -10;
  63.                 return;
  64.             }
  65.         } else if (y == 0) {
  66.             // 正右、右下角
  67.             if (cuda_board[tid + 16] == -1 || cuda_board[tid + 17] == -1) {
  68.                 cuda_board[tid] = -10;
  69.                 return;
  70.             }
  71.         } else {
  72.             if (cuda_board[tid - 16] == -1 || cuda_board[tid - 15] == -1 || cuda_board[tid + 16] == -1 || cuda_board[tid + 17] == -1) {
  73.                 cuda_board[tid] = -10;
  74.                 return;
  75.             }
  76.         }
  77.     } else if (x == 15) {
  78.         if (cuda_board[tid - 1] == -1) {
  79.             cuda_board[tid] = -10;
  80.             return;
  81.         }
  82.         if (y == 29) {
  83.             // 正左、左下角
  84.             if (cuda_board[tid - 16] == -1 || cuda_board[tid - 17] == -1) {
  85.                 cuda_board[tid] = -10;
  86.                 return;
  87.             }
  88.         } else if (y == 0) {
  89.             // 正右、右下角
  90.             if (cuda_board[tid + 16] == -1 || cuda_board[tid + 15] == -1) {
  91.                 cuda_board[tid] = -10;
  92.                 return;
  93.             }
  94.         } else {
  95.             if (cuda_board[tid - 16] == -1 || cuda_board[tid - 17] == -1 || cuda_board[tid + 16] == -1 || cuda_board[tid + 15] == -1) {
  96.                 cuda_board[tid] = -10;
  97.                 return;
  98.             }
  99.         }
  100.     } else {
  101.         if (cuda_board[tid - 1] == -1 || cuda_board[tid + 1] == -1) {
  102.             cuda_board[tid] = -10;
  103.             return;
  104.         }
  105.         if (y == 29) {
  106.             // 左
  107.             if (cuda_board[tid - 16] == -1 || cuda_board[tid - 17] == -1 || cuda_board[tid - 15] == -1) {
  108.                 cuda_board[tid] = -10;
  109.                 return;
  110.             }
  111.         } else if (y == 0) {
  112.             // 正右、右下角
  113.             if (cuda_board[tid + 16] == -1 || cuda_board[tid + 15] == -1 || cuda_board[tid + 17] == -1) {
  114.                 cuda_board[tid] = -10;
  115.                 return;
  116.             }
  117.         } else {
  118.             if (cuda_board[tid - 16] == -1 || cuda_board[tid - 17] == -1 || cuda_board[tid + 16] == -1 || cuda_board[tid + 15] == -1 ||
  119.             cuda_board[tid + 17] == -1 || cuda_board[tid - 15] == -1) {
  120.                 cuda_board[tid] = -10;
  121.                 return;
  122.             }
  123.         }
  124.     }
  125. }
  126.  
  127. __device__ bool is_3BV_on_island (int *cuda_board, unsigned int tid, unsigned int x, unsigned int y) {
  128.     // cuda_board[tid]是数字的情况下,判断是不是一个3BV
  129.     if (x == 0) {
  130.         // 正下
  131.         if (cuda_board[tid + 1] >= 0) {
  132.             return false;
  133.         }
  134.         if (y == 29) {
  135.             // 正左、左下角
  136.             if (cuda_board[tid - 16] >= 0 || cuda_board[tid - 15] >= 0) {
  137.                 return false;
  138.             }
  139.         } else if (y == 0) {
  140.             // 正右、右下角
  141.             if (cuda_board[tid + 16] >= 0 || cuda_board[tid + 17] >= 0) {
  142.                 return false;
  143.             }
  144.         } else {
  145.             if (cuda_board[tid - 16] >= 0 || cuda_board[tid - 15] >= 0 || cuda_board[tid + 16] >= 0 || cuda_board[tid + 17] >= 0) {
  146.                 return false;
  147.             }
  148.         }
  149.     } else if (x == 15) {
  150.         if (cuda_board[tid - 1] >= 0) {
  151.             return false;
  152.         }
  153.         if (y == 29) {
  154.             // 正左、左下角
  155.             if (cuda_board[tid - 16] >= 0 || cuda_board[tid - 17] >= 0) {
  156.                 return false;
  157.             }
  158.         } else if (y == 0) {
  159.             // 正右、右下角
  160.             if (cuda_board[tid + 16] >= 0 || cuda_board[tid + 15] >= 0) {
  161.                 return false;
  162.             }
  163.         } else {
  164.             if (cuda_board[tid - 16] >= 0 || cuda_board[tid - 17] >= 0 || cuda_board[tid + 16] >= 0 || cuda_board[tid + 15] >= 0) {
  165.                 return false;
  166.             }
  167.         }
  168.     } else {
  169.         if (cuda_board[tid - 1] >= 0 || cuda_board[tid + 1] >= 0) {
  170.             return false;
  171.         }
  172.         if (y == 29) {
  173.             // 左
  174.             if (cuda_board[tid - 16] >= 0 || cuda_board[tid - 17] >= 0 || cuda_board[tid - 15] >= 0) {
  175.                 return false;
  176.             }
  177.         } else if (y == 0) {
  178.             // 正右、右下角
  179.             if (cuda_board[tid + 16] >= 0 || cuda_board[tid + 15] >= 0 || cuda_board[tid + 17] >= 0) {
  180.                 return false;
  181.             }
  182.         } else {
  183.             if (cuda_board[tid - 16] >= 0 || cuda_board[tid - 17] >= 0 || cuda_board[tid + 16] >= 0 || cuda_board[tid + 15] >= 0 ||
  184.             cuda_board[tid + 17] >= 0 || cuda_board[tid - 15] >= 0) {
  185.                 return false;
  186.             }
  187.         }
  188.     }
  189.     return true;
  190. }
  191.  
  192. __device__ void check_surrounding_cells(int *board, bool *board_is_3BV, int *bbbv_current, bool *keep_running, unsigned int tid, unsigned int x, unsigned int y, unsigned int idN, unsigned int cell) {
  193.     // cuda_board[tid]是空(>=0)的情况下,处理八方
  194.     if (board[tid] == 0) {
  195.         // printf("666");
  196.         keep_running[idN] = true;
  197.         board[tid] = cell + 1;
  198.         board_is_3BV[tid] = true;
  199.         // atomicAdd(&bbbv_current[idN], 1);
  200.         return;
  201.     }
  202.     int min_op_id = 500;
  203.     if (x == 0) {
  204.         // 正下
  205.         if (board[tid + 1] > 0 && board[tid + 1] < min_op_id) {
  206.             min_op_id = board[tid + 1];
  207.         }
  208.         if (y == 29) {
  209.             // 正左、左下角
  210.             if (board[tid - 16] > 0 && board[tid - 16] < min_op_id) {
  211.                 min_op_id = board[tid - 16];
  212.             }
  213.             if (board[tid - 15] > 0 && board[tid - 15] < min_op_id) {
  214.                 min_op_id = board[tid - 15];
  215.             }
  216.         } else if (y == 0) {
  217.             // 正右、右下角
  218.             if (board[tid + 16] > 0 && board[tid + 16] < min_op_id) {
  219.                 min_op_id = board[tid + 16];
  220.             }
  221.             if (board[tid + 17] > 0 && board[tid + 17] < min_op_id) {
  222.                 min_op_id = board[tid + 17];
  223.             }
  224.         } else {
  225.             if (board[tid - 16] > 0 && board[tid - 16] < min_op_id) {
  226.                 min_op_id = board[tid - 16];
  227.             }
  228.             if (board[tid - 15] > 0 && board[tid - 15] < min_op_id) {
  229.                 min_op_id = board[tid - 15];
  230.             }
  231.             if (board[tid + 16] > 0 && board[tid + 16] < min_op_id) {
  232.                 min_op_id = board[tid + 16];
  233.             }
  234.             if (board[tid + 17] > 0 && board[tid + 17] < min_op_id) {
  235.                 min_op_id = board[tid + 17];
  236.             }
  237.         }
  238.     } else if (x == 15) {
  239.         if (board[tid - 1] > 0 && board[tid - 1] < min_op_id) {
  240.             min_op_id = board[tid - 1];
  241.         }
  242.         if (y == 29) {
  243.             // 正左、左下角
  244.             if (board[tid - 16] > 0 && board[tid - 16] < min_op_id) {
  245.                 min_op_id = board[tid - 16];
  246.             }
  247.             if (board[tid - 17] > 0 && board[tid - 17] < min_op_id) {
  248.                 min_op_id = board[tid - 17];
  249.             }
  250.         } else if (y == 0) {
  251.             // 正右、右下角
  252.             if (board[tid + 15] > 0 && board[tid + 15] < min_op_id) {
  253.                 min_op_id = board[tid + 15];
  254.             }
  255.             if (board[tid + 16] > 0 && board[tid + 16] < min_op_id) {
  256.                 min_op_id = board[tid + 16];
  257.             }
  258.         } else {
  259.             if (board[tid - 16] > 0 && board[tid - 16] < min_op_id) {
  260.                 min_op_id = board[tid - 16];
  261.             }
  262.             if (board[tid - 17] > 0 && board[tid - 17] < min_op_id) {
  263.                 min_op_id = board[tid - 17];
  264.             }
  265.             if (board[tid + 15] > 0 && board[tid + 15] < min_op_id) {
  266.                 min_op_id = board[tid + 15];
  267.             }
  268.             if (board[tid + 16] > 0 && board[tid + 16] < min_op_id) {
  269.                 min_op_id = board[tid + 16];
  270.             }
  271.         }
  272.     } else {
  273.         if (board[tid - 1] > 0 && board[tid - 1] < min_op_id) {
  274.             min_op_id = board[tid - 1];
  275.         }
  276.         if (board[tid + 1] > 0 && board[tid + 1] < min_op_id) {
  277.             min_op_id = board[tid + 1];
  278.         }
  279.         if (y == 29) {
  280.             // 左
  281.             if (board[tid - 15] > 0 && board[tid - 15] < min_op_id) {
  282.                 min_op_id = board[tid - 15];
  283.             }
  284.             if (board[tid - 16] > 0 && board[tid - 16] < min_op_id) {
  285.                 min_op_id = board[tid - 16];
  286.             }
  287.             if (board[tid - 17] > 0 && board[tid - 17] < min_op_id) {
  288.                 min_op_id = board[tid - 17];
  289.             }
  290.         } else if (y == 0) {
  291.             // 正右、右下角
  292.             if (board[tid + 15] > 0 && board[tid + 15] < min_op_id) {
  293.                 min_op_id = board[tid + 15];
  294.             }
  295.             if (board[tid + 16] > 0 && board[tid + 16] < min_op_id) {
  296.                 min_op_id = board[tid + 16];
  297.             }
  298.             if (board[tid + 17] > 0 && board[tid + 17] < min_op_id) {
  299.                 min_op_id = board[tid + 17];
  300.             }
  301.         } else {
  302.             if (board[tid - 15] > 0 && board[tid - 15] < min_op_id) {
  303.                 min_op_id = board[tid - 15];
  304.             }
  305.             if (board[tid - 16] > 0 && board[tid - 16] < min_op_id) {
  306.                 min_op_id = board[tid - 16];
  307.             }
  308.             if (board[tid - 17] > 0 && board[tid - 17] < min_op_id) {
  309.                 min_op_id = board[tid - 17];
  310.             }
  311.             if (board[tid + 15] > 0 && board[tid + 15] < min_op_id) {
  312.                 min_op_id = board[tid + 15];
  313.             }
  314.             if (board[tid + 16] > 0 && board[tid + 16] < min_op_id) {
  315.                 min_op_id = board[tid + 16];
  316.             }
  317.             if (board[tid + 17] > 0 && board[tid + 17] < min_op_id) {
  318.                 min_op_id = board[tid + 17];
  319.             }
  320.         }
  321.     }
  322.     // printf("'%u', ", min_op_id);
  323.     if (min_op_id < board[tid]) {
  324.         // printf("666");
  325.         board[tid] = min_op_id;
  326.         keep_running[idN] = true;
  327.         if (board_is_3BV[tid]) {
  328.             board_is_3BV[tid] = false;
  329.             // atomicSub(&bbbv_current[idN], 1);
  330.         }
  331.     }
  332.     return;
  333. }
  334.  
  335. __global__ void laymine_cal_3BV(int *board, unsigned long long *bbbv, curandStateMRG32k3a_t *states, bool *board_is_3BV, int *bbbv_current, bool *keep_running) {
  336.    
  337.     // 全局的第几个格子
  338.     unsigned long long tid = threadIdx.x + blockIdx.x * 480;
  339.     // 第几局
  340.     const unsigned int idN = (unsigned int) blockIdx.x;
  341.     // 计算是本局的第几个格子
  342.     const unsigned int cell = tid % 480;
  343.     // 第几行
  344.     const unsigned int x = cell & 0x0000000f;
  345.     // 第几列
  346.     const unsigned int y = cell >> 4;
  347.     // 480个线程,前99个同时负责埋雷
  348.     bool can_laymine = (cell < 99)? true:false;
  349.     const unsigned int mines_id = blockIdx.x * 99 + cell;
  350.     const unsigned int mines_id_start_at = tid - cell;
  351.  
  352.     for (unsigned long long k = 0; k < 1000; k++) {
  353.         // 循环:完成埋雷、算3BV
  354.         // 初始化整个局面
  355.         board[tid] = 0;
  356.         board_is_3BV[tid] = false;
  357.         __syncthreads();
  358.         // 并行地为所有局面的所有位置埋雷
  359.         if (can_laymine) {
  360.             int old;
  361.             do {
  362.                 int cell_id = curand_int(480, &states[mines_id]) + mines_id_start_at;
  363.                 old = atomicCAS(&board[cell_id], 0, -1);
  364.             } while (old == -1);
  365.         }
  366.         // 块内同步
  367.         __syncthreads();
  368.         cal_num_for_cell(board, tid, x, y);
  369.         // 块内同步
  370.         __syncthreads();
  371.         do {
  372.             // 初始化op计算完成标志
  373.             keep_running[idN] = false;
  374.             __syncthreads();
  375.             if (board[tid] == -10) {
  376.                 if (!board_is_3BV[tid]) {
  377.                     board_is_3BV[tid] = true;
  378.                     if (is_3BV_on_island (board, tid, x, y)) {
  379.                         atomicAdd(&bbbv_current[idN], 1);
  380.                     }
  381.                 }
  382.             } else if (board[tid] >= 0) {
  383.                 check_surrounding_cells(board, board_is_3BV, bbbv_current, keep_running, tid, x, y, idN, cell);
  384.                 // flag_ok = false;
  385.             }
  386.             __syncthreads();
  387.         } while(keep_running[idN]);
  388.        
  389.         if (board_is_3BV[tid] && board[tid] >= 0) {
  390.             atomicAdd(&bbbv_current[idN], 1);
  391.         }
  392.         __syncthreads();
  393.  
  394.         if (threadIdx.x == 0) {
  395.             atomicAdd(&bbbv[bbbv_current[idN]], 1);
  396.             // printf("666, %d; ", bbbv_current[idN]);
  397.             bbbv_current[idN] = 0;
  398.         }
  399.         // 每个线程块的第一个线程负责统计结果
  400.  
  401.     }
  402. }
  403.  
  404. int main(void) {
  405.  
  406.     // int nx = GRIDSIZE;
  407.     // int ny = BLOCKSIZE;
  408.  
  409.     // bool *flag_ok = true;
  410.     bool *cuda_flag_ok; // 连通域计数时,是否全部计算完
  411.     cudaMalloc(&cuda_flag_ok, sizeof(bool));
  412.     // cudaMemcpy(cuda_flag_ok, flag_ok, sizeof(bool), cudaMemcpyHostToDevice);
  413.    
  414.     const int M_bbbv = sizeof(unsigned long long) * 382;
  415.  
  416.     // 3BV记录在此
  417.     unsigned long long bbbv[382] = {0};
  418.     unsigned long long *cuda_bbbv;
  419.     cudaMalloc(&cuda_bbbv, M_bbbv);
  420.     cudaMemcpy(cuda_bbbv, bbbv, M_bbbv, cudaMemcpyHostToDevice);
  421.  
  422.     // 是否继续跑
  423.     bool *cuda_keep_running;
  424.     cudaMalloc(&cuda_keep_running, sizeof(bool) * N);
  425.  
  426.     int *cuda_bbbv_current;
  427.     cudaMalloc(&cuda_bbbv_current, sizeof(int) * N);
  428.  
  429.     // 随机数状态
  430.     curandStateMRG32k3a_t *cuda_states;
  431.     cudaMalloc(&cuda_states, sizeof(curandStateMRG32k3a_t) * 99 * N);
  432.     // 初始化随机数
  433.     init_all_curand <<< N, 99 >>> (cuda_states);
  434.     int call_back = cudaDeviceSynchronize();
  435.  
  436.     // 辅助记录是不是op的变量
  437.     bool *cuda_board_is_3BV;
  438.     cudaMalloc(&cuda_board_is_3BV, sizeof(bool) * 480 * N);
  439.  
  440.     // 局面
  441.     int *cuda_board;
  442.     cudaMalloc(&cuda_board, sizeof(int) * 480 * N);
  443.  
  444.     clock_t start, finish;
  445.     float costtime;
  446.     start = clock();
  447.  
  448.     // 埋雷、计算3bv
  449.     laymine_cal_3BV <<< N, 480 >>> (cuda_board, cuda_bbbv, cuda_states, cuda_board_is_3BV, cuda_bbbv_current, cuda_keep_running);
  450.  
  451.     call_back = cudaDeviceSynchronize();
  452.     finish = clock();
  453.     //得到两次记录之间的时间差
  454.     costtime = (float)(finish - start) / CLOCKS_PER_SEC;
  455.  
  456.     cudaMemcpy(&bbbv, cuda_bbbv, M_bbbv, cudaMemcpyDeviceToHost);
  457.  
  458.     // //time_t t;
  459.     // //srand((unsigned)time(&t));
  460.     // //printf(" %d \n", (int)(rand() & 0xff));
  461.  
  462.     int aaa = 0;
  463.     for (int n = 0; n < 381; ++n)
  464.     {
  465.         aaa += bbbv[n];
  466.         printf("%d: %u\n", n, bbbv[n]);
  467.     }
  468.     printf("一共: %d\n", aaa);
  469.     printf("耗时:%f \n", costtime);
  470.     printf("速度:%f \n", aaa/costtime);
  471.  
  472.     cudaDeviceReset();
  473.     cudaFree(cuda_bbbv);
  474.     cudaFree(cuda_keep_running);
  475.     cudaFree(cuda_bbbv_current);
  476.     cudaFree(cuda_flag_ok);
  477.     cudaFree(cuda_states);
  478.     cudaFree(cuda_board);
  479.     cudaFree(cuda_board_is_3BV);
  480.     return 0;
  481.  
  482.  
  483.  
  484. }
Advertisement
Add Comment
Please, Sign In to add comment