Advertisement
Guest User

Untitled

a guest
Dec 7th, 2019
115
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 8.50 KB | None | 0 0
  1.  
  2. #include <stdio.h>
  3. #include <stdlib.h>
  4. #include <math.h>
  5. #include <string.h>
  6. #include <iostream>
  7. #define min(a,b) (((a) < (b)) ? (a):(b))
  8.  
  9. #define d_type float
  10. #define b_type int
  11. #define NParameter 6
  12.  
  13. #define CHin 32
  14. #define CHout 64
  15. #define R_in 128
  16. #define C_in 128
  17. #define K 6
  18. #define S 1
  19. #define R_out 126
  20. #define C_out 126
  21.  
  22. #define Tchout 4
  23. #define Tchin 4
  24. #define OUTBUF_CH 64
  25. #define OUTBUF_ROW 4
  26. #define OUTBUF_COL 4
  27. #define INBUF_CH 32
  28. #define INBUF_ROW 6
  29. #define INBUF_COL 6
  30. #define WBUF_CHO 64
  31. #define WBUF_CHI 32
  32. #define WBUF_K 6
  33.  
  34.  
  35. static b_type B[6][6] = { {4, 0, 0, 0, 0, 0},
  36.                          {0, -4, 4, -2, 2, 4},
  37.                          {-5, -4, -4, -1, -1, 0},
  38.                          {0, 1, -1, 2, -2, -5},
  39.                          {1, 1, 1, 1, 1, 0},
  40.                          {0, 0, 0, 0, 0, 1} };
  41.  
  42. static b_type BT[6][6] = { {4, 0, -5, 0, 1, 0},
  43.                           {0, -4, -4, 1, 1, 0},
  44.                           {0, 4, -4, -1, 1, 0},
  45.                           {0, -2, -1, 2, 1, 0},
  46.                           {0, 2, -1, -2, 1, 0},
  47.                           {0, 4, 0, -5, 0, 1} };
  48.  
  49. static b_type A[6][4] = { {1, 0, 0, 0},
  50.                          {1, 1, 1, 1},
  51.                          {1, -1, 1, -1},
  52.                          {1, 2, 4, 8},
  53.                          {1, -2, 4, -8},
  54.                          {0, 0, 0, 1} };
  55.  
  56. static b_type AT[4][6] = { {1, 1, 1, 1, 1, 0},
  57.                           {0, 1, -1, 2, -2, 0},
  58.                           {0, 1, 1, 4, 4, 0},
  59.                           {0, 1, -1, 8, -8, 1} };
  60.  
  61. d_type In[CHin][R_in][C_in], W[CHout][CHin][K][K];
  62. d_type Out[CHout][R_out][C_out], Gold[CHout][R_out][C_out];
  63.  
  64.  
  65. static inline int cal_index_3d(int ch, int r, int c, int ROW, int COL) {
  66.     return c + (r + ch * ROW) * COL;
  67. }
  68.  
  69. static inline int cal_index_4d(int chout, int chin, int r, int c,
  70.     int CHIN, int ROW, int COL) {
  71.     return c + (r + (chin + chout * CHIN) * ROW) * COL;
  72. }
  73.  
  74.  
  75. void load_weight(d_type w_buf[WBUF_CHO][WBUF_CHI][WBUF_K][WBUF_K],
  76.     d_type* weight_array)
  77. {
  78.     for (int cho = 0; cho < WBUF_CHO; ++cho)
  79.         for (int chi = 0; chi < WBUF_CHI; ++chi)
  80.             for (int i = 0; i < WBUF_K; ++i)
  81.                 for (int j = 0; j < WBUF_K; ++j)
  82.                     w_buf[cho][chi][i][j] = weight_array[cal_index_4d(cho, chi, i, j, CHin, K, K)];
  83. }
  84.  
  85.  
  86.  
  87. void load_input_buf(int row_start, int row_len,
  88.     int col_start, int col_len,
  89.     d_type in_buf[INBUF_CH][INBUF_ROW][INBUF_COL],
  90.     d_type* in_array)
  91. {
  92.  
  93.     for (int chi = 0; chi < CHin; ++chi)
  94.     {
  95.         for (int row_in_bias = 0, row_in = row_start * S; row_in_bias < row_len; ++row_in_bias, ++row_in)
  96.         {
  97.             for (int col_in_bias = 0, col_in = col_start * S; col_in_bias < col_len; ++col_in_bias, ++col_in)
  98.             {
  99.                 in_buf[chi][row_in_bias][col_in_bias] = in_array[cal_index_3d(chi, row_in, col_in, R_in, C_in)];
  100.             }
  101.         }
  102.     }
  103. }
  104.  
  105.  
  106.  
  107. void clear_out_buf(d_type out_buf[OUTBUF_CH][OUTBUF_ROW][OUTBUF_COL])
  108. {
  109.     for (int row = 0; row < OUTBUF_ROW; ++row)
  110.         for (int col = 0; col < OUTBUF_COL; ++col)
  111.             for (int cho = 0; cho < OUTBUF_CH; ++cho)
  112.                 out_buf[cho][row][col] = 0;
  113. }
  114.  
  115.  
  116.  
  117. void store_out_buf(int row_start, int row_len, int col_start, int col_len,
  118.     d_type out_buf[OUTBUF_CH][OUTBUF_ROW][OUTBUF_COL],
  119.     d_type* out_array)
  120. {
  121.     for (int cho = 0; cho < CHout; ++cho)
  122.         for (int row_out_bias = 0, row_out = row_start; row_out_bias < row_len; ++row_out_bias, ++row_out)
  123.             for (int col_out_bias = 0, col_out = col_start; col_out_bias < col_len; ++col_out_bias, ++col_out)
  124.  
  125.                 out_array[cal_index_3d(cho, row_out, col_out, R_out, C_out)] = out_buf[cho][row_out_bias][col_out_bias];
  126. }
  127.  
  128.  
  129.  
  130. void winograd(int to, int ti,
  131.     d_type out_buf[OUTBUF_CH][OUTBUF_ROW][OUTBUF_COL],
  132.     d_type in_buf[INBUF_CH][INBUF_ROW][INBUF_COL],
  133.     d_type w_buf[WBUF_CHO][WBUF_CHI][WBUF_K][WBUF_K], int flag)
  134. {
  135.     d_type uv[6][6], u[6][6], t[6][6], tmp = 0;
  136.  
  137.     // U = B^T Z B
  138.     for (int i = 0; i < 6; i++)
  139.     {
  140.         for (int j = 0; j < 6; j++)
  141.         {
  142.             tmp = 0;
  143.             for (int k = 0; k < 6; k++)
  144.             {
  145.                 tmp += BT[i][k] * in_buf[ti][k][j];
  146.             }
  147.             t[i][j] = tmp;
  148.         }
  149.     }
  150.  
  151.     for (int i = 0; i < 6; i++)
  152.     {
  153.         for (int j = 0; j < 6; j++)
  154.         {
  155.             tmp = 0;
  156.             for (int k = 0; k < 6; k++)
  157.             {
  158.                 tmp += t[i][k] * B[k][j];
  159.             }
  160.             u[i][j] = tmp;
  161.         }
  162.     }
  163.  
  164.     for (int i = 0; i < 6; i++)
  165.     {
  166.         for (int j = 0; j < 6; j++)
  167.         {
  168.             uv[i][j] = u[i][j] * w_buf[to][ti][i][j];
  169.         }
  170.     }
  171.  
  172.     // Y = A^T UV A
  173.     for (int i = 0; i < 4; i++)
  174.     {
  175.         for (int j = 0; j < 6; j++)
  176.         {
  177.             tmp = 0;
  178.             for (int k = 0; k < 6; k++)
  179.             {
  180.                 tmp += AT[i][k] * uv[k][j];
  181.             }
  182.             t[i][j] = tmp;
  183.         }
  184.     }
  185.  
  186.     for (int i = 0; i < 4; i++)
  187.     {
  188.         for (int j = 0; j < 4; j++)
  189.         {
  190.             tmp = 0;
  191.             for (int k = 0; k < 6; k++)
  192.             {
  193.                 tmp += uv[i][k] * A[k][j];
  194.             }
  195.             out_buf[to][i][j] += tmp;
  196.         }
  197.     }
  198.  
  199.     if (flag == 1)
  200.     {
  201.         printf("**********************ti = %d**************************\n", ti);
  202.  
  203.         for (int i = 0; i < 6; i++)
  204.         {
  205.             for (int j = 0; j < 6; j++)
  206.             {
  207.                 printf("in_buf[%d][%d][%d] is %f\n", ti, i, j, in_buf[ti][i][j]);
  208.  
  209.             }
  210.         }
  211.         printf("-------------------------------------------------------\n");
  212.         for (int i = 0; i < 6; i++)
  213.         {
  214.             for (int j = 0; j < 6; j++)
  215.             {
  216.                 printf("w_buf[0][%d][%d][%d] is %f\n", ti, i, j, w_buf[0][ti][i][j]);
  217.  
  218.             }
  219.         }
  220.         printf("-------------------------------------------------------\n");
  221.         for (int i = 0; i < 4; i++)
  222.         {
  223.             for (int j = 0; j < 4; j++)
  224.             {
  225.                 printf("out_buf[0][%d][%d] is %f\n", i, j, out_buf[0][i][j]);
  226.  
  227.             }
  228.         }
  229.  
  230.     }
  231.  
  232. }
  233.  
  234.  
  235. void cnn(d_type* In, d_type* Out, d_type* W, int* Parameter)
  236. {
  237.  
  238.     /*
  239.     In  : Input feature map, CHin*R*C
  240.     Out : Output feature map, CHout*Rout*Cout
  241.     W : weights, CHout*CHin*Kr*Kc
  242.     Parameter:  CHout|CHin|R|C|K|S
  243.     */
  244.  
  245.     d_type out_buf[OUTBUF_CH][OUTBUF_ROW][OUTBUF_COL];
  246.     d_type in_buf[INBUF_CH][INBUF_ROW][INBUF_COL];
  247.     d_type w_buf[WBUF_CHO][WBUF_CHI][WBUF_K][WBUF_K];
  248.  
  249.     load_weight(w_buf, W);      // Load all weight
  250.  
  251.  
  252. OUTPUT_ROW:
  253.     for (int row = 0; row < R_out; row += OUTBUF_ROW) {
  254.         row = (row > 122) ? 122 : row;
  255.        
  256.     OUTPUT_COLOMN:
  257.         for (int col = 0; col < C_out; col += OUTBUF_COL) {
  258.             col = (col > 122) ? 122 : col;
  259.  
  260.             load_input_buf(row, INBUF_ROW, col, INBUF_COL, in_buf, In);
  261.  
  262.             clear_out_buf(out_buf);
  263.  
  264.         OUTPUT_CHANNEL:
  265.             for (int cho = 0; cho < CHout; cho += Tchout) {
  266.                
  267.             INPUT_CHANNEL:
  268.                 for (int chi = 0; chi < CHin; chi += Tchin) {
  269.  
  270.                     TILED_OUTPUT_CHANNEL:
  271.                     for (int to_bias = 0; to_bias < min(CHout - cho, Tchout); ++to_bias) {
  272.                        
  273.                         TILED_INPUT_CHANNEL:
  274.                         for (int ti_bias = 0; ti_bias < min(CHin - chi, Tchin); ++ti_bias) {
  275.                             int flag = 0;
  276.                             if (row == 0 && col == 0 && cho + to_bias == 0) {
  277.                                 flag = 1;
  278.                             }
  279.                             winograd(cho + to_bias, chi + ti_bias, out_buf, in_buf, w_buf, flag);
  280.  
  281.                         }
  282.                     }
  283.                 }
  284.             }
  285.             if (row == 0 && col == 0) {
  286.                 for (int i = 0; i < 4; i++)
  287.                 {
  288.                     for (int j = 0; j < 4; j++)
  289.                     {
  290.                         printf("out_buf[%d][%d] is %f\n", i, j, out_buf[0][i][j]);
  291.                        
  292.                     }
  293.                 }
  294.                
  295.             }
  296.             store_out_buf(row, OUTBUF_ROW, col, OUTBUF_COL, out_buf, Out);
  297.  
  298.         }
  299.     }
  300. }
  301.  
  302.  
  303. int main() {
  304.     FILE* fin;
  305.     FILE* fout;
  306.     FILE* fw;
  307.  
  308.     errno_t err;
  309.  
  310.     err = fopen_s(&fin, "D:\\CourseDocument\\ComputerArchitecturePractice\\Lab7\\Lab7\\dat\\sample_input.dat", "r");
  311.     err = fopen_s(&fout, "D:\\CourseDocument\\ComputerArchitecturePractice\\Lab7\\Lab7\\dat\\sample_out.dat", "r");
  312.     err = fopen_s(&fw, "D:\\CourseDocument\\ComputerArchitecturePractice\\Lab7\\Lab7\\dat\\sample_weight.dat", "r");
  313.  
  314.     for (int i = 0; i < CHin; i++)
  315.     {
  316.         for (int j = 0; j < R_in; j++)
  317.         {
  318.             for (int x = 0; x < C_in; x++)
  319.             {
  320.                 fscanf_s(fin, "%f", &In[i][j][x]);
  321.  
  322.             }
  323.         }
  324.     }
  325.  
  326.     for (int i = 0; i < CHout; i++)
  327.     {
  328.         for (int j = 0; j < CHin; j++)
  329.         {
  330.             for (int x = 0; x < K; x++)
  331.             {
  332.                 for (int y = 0; y < K; y++)
  333.                 {
  334.                     fscanf_s(fw, "%f\n", &W[i][j][x][y]);
  335.                 }
  336.             }
  337.         }
  338.     }
  339.  
  340.    
  341.     for (int i = 0; i < CHout; i++)
  342.     {
  343.         for (int j = 0; j < R_out; j++)
  344.         {
  345.             for (int x = 0; x < C_out; x++)
  346.             {
  347.                 Out[i][j][x] = 0;
  348.             }
  349.         }
  350.     }
  351.    
  352.     int Parameter[] = { CHin, CHout, R_in, C_in, K, S };
  353.     cnn(In[0][0], Out[0][0], W[0][0][0], Parameter);
  354.  
  355.     for (int i = 0; i < CHout; i++)
  356.     {
  357.         for (int j = 0; j < R_out; j++)
  358.         {
  359.             for (int x = 0; x < C_out; x++)
  360.             {
  361.                 fscanf_s(fout, "%f\n", &Gold[i][j][x]);
  362.             }
  363.         }
  364.     }
  365.  
  366.     float eps1 = 0.01;
  367.     int flag = 0;
  368.     for (int i = 0; i < CHout; i++)
  369.     {
  370.         for (int j = 0; j < R_out; j++)
  371.         {
  372.             for (int x = 0; x < C_out; x++)
  373.             {
  374.  
  375.                 if (fabs(Gold[i][j][x] - Out[i][j][x]) > eps1)
  376.                 {
  377.                     printf("Not Match!\n");
  378.                     printf("Gold is %f\tOut is %f\n", Gold[i][j][x], Out[i][j][x]);
  379.                     flag += 1;
  380.                     if (flag == 10) {
  381.                         fclose(fin);
  382.                         fclose(fout);
  383.                         fclose(fw);
  384.                         return 0;
  385.                     }
  386.                 }
  387.             }
  388.         }
  389.     }
  390.     printf("Match!\n");
  391.     fclose(fin);
  392.     fclose(fout);
  393.     fclose(fw);
  394.  
  395.     return 0;
  396. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement