Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <stdio.h>
- #include <stdlib.h>
- #include <math.h>
- #include <string.h>
- #include <iostream>
/* Smaller of two values. NOTE: each argument may be evaluated twice, so do
 * not pass expressions with side effects. */
#define min(a,b) (((a) < (b)) ? (a):(b))

/* Element type of feature maps / weights, and of the integer transform tables. */
#define d_type float
#define b_type int

/* Number of entries in the Parameter array handed to cnn(). */
#define NParameter 6

/* Layer geometry. */
#define CHin 32     /* input channels                                   */
#define CHout 64    /* output channels                                  */
#define R_in 128    /* input rows                                       */
#define C_in 128    /* input columns                                    */
#define K 6         /* rows/cols of each stored weight tile (6x6);      */
                    /* NOTE(review): R_out = R_in - 2 suggests a 3x3    */
                    /* kernel stored pre-transformed - confirm with the */
                    /* weight generator.                                */
#define S 1         /* stride                                           */
#define R_out 126   /* output rows                                      */
#define C_out 126   /* output columns                                   */

/* Channel tiling factors used by the loop nest in cnn(). */
#define Tchout 4
#define Tchin 4

/* On-chip style buffer shapes: a 6x6 input tile yields a 4x4 output tile. */
#define OUTBUF_CH 64
#define OUTBUF_ROW 4
#define OUTBUF_COL 4
#define INBUF_CH 32
#define INBUF_ROW 6
#define INBUF_COL 6
#define WBUF_CHO 64
#define WBUF_CHI 32
#define WBUF_K 6

/*
 * Transform matrices used by winograd(): the input transform is
 * U = B^T d B and the output transform is Y = A^T M A.
 * B/BT and A/AT are stored as explicit transposes of each other.
 * They are read-only lookup tables, hence const.
 */
static const b_type B[6][6] = {
    { 4,  0,  0,  0,  0,  0},
    { 0, -4,  4, -2,  2,  4},
    {-5, -4, -4, -1, -1,  0},
    { 0,  1, -1,  2, -2, -5},
    { 1,  1,  1,  1,  1,  0},
    { 0,  0,  0,  0,  0,  1}
};

static const b_type BT[6][6] = {
    {4,  0, -5,  0, 1, 0},
    {0, -4, -4,  1, 1, 0},
    {0,  4, -4, -1, 1, 0},
    {0, -2, -1,  2, 1, 0},
    {0,  2, -1, -2, 1, 0},
    {0,  4,  0, -5, 0, 1}
};

static const b_type A[6][4] = {
    {1,  0, 0,  0},
    {1,  1, 1,  1},
    {1, -1, 1, -1},
    {1,  2, 4,  8},
    {1, -2, 4, -8},
    {0,  0, 0,  1}
};

static const b_type AT[4][6] = {
    {1, 1,  1, 1,  1, 0},
    {0, 1, -1, 2, -2, 0},
    {0, 1,  1, 4,  4, 0},
    {0, 1, -1, 8, -8, 1}
};

/* Host-side tensors: input feature map and weights ... */
d_type In[CHin][R_in][C_in], W[CHout][CHin][K][K];
/* ... the computed output and the golden reference it is compared against. */
d_type Out[CHout][R_out][C_out], Gold[CHout][R_out][C_out];
/*
 * Flattens a (ch, r, c) coordinate into a linear offset for a tensor laid
 * out channel-major with ROW rows and COL columns per channel.
 */
static inline int cal_index_3d(int ch, int r, int c, int ROW, int COL) {
    const int row_index = ch * ROW + r;   /* row counted across all channels */
    return row_index * COL + c;
}
/*
 * Flattens a (chout, chin, r, c) coordinate into a linear offset for a
 * tensor laid out as [chout][chin][ROW][COL], with CHIN input channels per
 * output channel.
 */
static inline int cal_index_4d(int chout, int chin, int r, int c,
                               int CHIN, int ROW, int COL) {
    const int kernel_index = chout * CHIN + chin;   /* which ROW x COL tile */
    const int row_index = kernel_index * ROW + r;   /* row across all tiles */
    return row_index * COL + c;
}
- void load_weight(d_type w_buf[WBUF_CHO][WBUF_CHI][WBUF_K][WBUF_K],
- d_type* weight_array)
- {
- for (int cho = 0; cho < WBUF_CHO; ++cho)
- for (int chi = 0; chi < WBUF_CHI; ++chi)
- for (int i = 0; i < WBUF_K; ++i)
- for (int j = 0; j < WBUF_K; ++j)
- w_buf[cho][chi][i][j] = weight_array[cal_index_4d(cho, chi, i, j, CHin, K, K)];
- }
- void load_input_buf(int row_start, int row_len,
- int col_start, int col_len,
- d_type in_buf[INBUF_CH][INBUF_ROW][INBUF_COL],
- d_type* in_array)
- {
- for (int chi = 0; chi < CHin; ++chi)
- {
- for (int row_in_bias = 0, row_in = row_start * S; row_in_bias < row_len; ++row_in_bias, ++row_in)
- {
- for (int col_in_bias = 0, col_in = col_start * S; col_in_bias < col_len; ++col_in_bias, ++col_in)
- {
- in_buf[chi][row_in_bias][col_in_bias] = in_array[cal_index_3d(chi, row_in, col_in, R_in, C_in)];
- }
- }
- }
- }
- void clear_out_buf(d_type out_buf[OUTBUF_CH][OUTBUF_ROW][OUTBUF_COL])
- {
- for (int row = 0; row < OUTBUF_ROW; ++row)
- for (int col = 0; col < OUTBUF_COL; ++col)
- for (int cho = 0; cho < OUTBUF_CH; ++cho)
- out_buf[cho][row][col] = 0;
- }
- void store_out_buf(int row_start, int row_len, int col_start, int col_len,
- d_type out_buf[OUTBUF_CH][OUTBUF_ROW][OUTBUF_COL],
- d_type* out_array)
- {
- for (int cho = 0; cho < CHout; ++cho)
- for (int row_out_bias = 0, row_out = row_start; row_out_bias < row_len; ++row_out_bias, ++row_out)
- for (int col_out_bias = 0, col_out = col_start; col_out_bias < col_len; ++col_out_bias, ++col_out)
- out_array[cal_index_3d(cho, row_out, col_out, R_out, C_out)] = out_buf[cho][row_out_bias][col_out_bias];
- }
- void winograd(int to, int ti,
- d_type out_buf[OUTBUF_CH][OUTBUF_ROW][OUTBUF_COL],
- d_type in_buf[INBUF_CH][INBUF_ROW][INBUF_COL],
- d_type w_buf[WBUF_CHO][WBUF_CHI][WBUF_K][WBUF_K], int flag)
- {
- d_type uv[6][6], u[6][6], t[6][6], tmp = 0;
- // U = B^T Z B
- for (int i = 0; i < 6; i++)
- {
- for (int j = 0; j < 6; j++)
- {
- tmp = 0;
- for (int k = 0; k < 6; k++)
- {
- tmp += BT[i][k] * in_buf[ti][k][j];
- }
- t[i][j] = tmp;
- }
- }
- for (int i = 0; i < 6; i++)
- {
- for (int j = 0; j < 6; j++)
- {
- tmp = 0;
- for (int k = 0; k < 6; k++)
- {
- tmp += t[i][k] * B[k][j];
- }
- u[i][j] = tmp;
- }
- }
- for (int i = 0; i < 6; i++)
- {
- for (int j = 0; j < 6; j++)
- {
- uv[i][j] = u[i][j] * w_buf[to][ti][i][j];
- }
- }
- // Y = A^T UV A
- for (int i = 0; i < 4; i++)
- {
- for (int j = 0; j < 6; j++)
- {
- tmp = 0;
- for (int k = 0; k < 6; k++)
- {
- tmp += AT[i][k] * uv[k][j];
- }
- t[i][j] = tmp;
- }
- }
- for (int i = 0; i < 4; i++)
- {
- for (int j = 0; j < 4; j++)
- {
- tmp = 0;
- for (int k = 0; k < 6; k++)
- {
- tmp += uv[i][k] * A[k][j];
- }
- out_buf[to][i][j] += tmp;
- }
- }
- if (flag == 1)
- {
- printf("**********************ti = %d**************************\n", ti);
- for (int i = 0; i < 6; i++)
- {
- for (int j = 0; j < 6; j++)
- {
- printf("in_buf[%d][%d][%d] is %f\n", ti, i, j, in_buf[ti][i][j]);
- }
- }
- printf("-------------------------------------------------------\n");
- for (int i = 0; i < 6; i++)
- {
- for (int j = 0; j < 6; j++)
- {
- printf("w_buf[0][%d][%d][%d] is %f\n", ti, i, j, w_buf[0][ti][i][j]);
- }
- }
- printf("-------------------------------------------------------\n");
- for (int i = 0; i < 4; i++)
- {
- for (int j = 0; j < 4; j++)
- {
- printf("out_buf[0][%d][%d] is %f\n", i, j, out_buf[0][i][j]);
- }
- }
- }
- }
- void cnn(d_type* In, d_type* Out, d_type* W, int* Parameter)
- {
- /*
- In : Input feature map, CHin*R*C
- Out : Output feature map, CHout*Rout*Cout
- W : weights, CHout*CHin*Kr*Kc
- Parameter: CHout|CHin|R|C|K|S
- */
- d_type out_buf[OUTBUF_CH][OUTBUF_ROW][OUTBUF_COL];
- d_type in_buf[INBUF_CH][INBUF_ROW][INBUF_COL];
- d_type w_buf[WBUF_CHO][WBUF_CHI][WBUF_K][WBUF_K];
- load_weight(w_buf, W); // Load all weight
- OUTPUT_ROW:
- for (int row = 0; row < R_out; row += OUTBUF_ROW) {
- row = (row > 122) ? 122 : row;
- OUTPUT_COLOMN:
- for (int col = 0; col < C_out; col += OUTBUF_COL) {
- col = (col > 122) ? 122 : col;
- load_input_buf(row, INBUF_ROW, col, INBUF_COL, in_buf, In);
- clear_out_buf(out_buf);
- OUTPUT_CHANNEL:
- for (int cho = 0; cho < CHout; cho += Tchout) {
- INPUT_CHANNEL:
- for (int chi = 0; chi < CHin; chi += Tchin) {
- TILED_OUTPUT_CHANNEL:
- for (int to_bias = 0; to_bias < min(CHout - cho, Tchout); ++to_bias) {
- TILED_INPUT_CHANNEL:
- for (int ti_bias = 0; ti_bias < min(CHin - chi, Tchin); ++ti_bias) {
- int flag = 0;
- if (row == 0 && col == 0 && cho + to_bias == 0) {
- flag = 1;
- }
- winograd(cho + to_bias, chi + ti_bias, out_buf, in_buf, w_buf, flag);
- }
- }
- }
- }
- if (row == 0 && col == 0) {
- for (int i = 0; i < 4; i++)
- {
- for (int j = 0; j < 4; j++)
- {
- printf("out_buf[%d][%d] is %f\n", i, j, out_buf[0][i][j]);
- }
- }
- }
- store_out_buf(row, OUTBUF_ROW, col, OUTBUF_COL, out_buf, Out);
- }
- }
- }
- int main() {
- FILE* fin;
- FILE* fout;
- FILE* fw;
- errno_t err;
- err = fopen_s(&fin, "D:\\CourseDocument\\ComputerArchitecturePractice\\Lab7\\Lab7\\dat\\sample_input.dat", "r");
- err = fopen_s(&fout, "D:\\CourseDocument\\ComputerArchitecturePractice\\Lab7\\Lab7\\dat\\sample_out.dat", "r");
- err = fopen_s(&fw, "D:\\CourseDocument\\ComputerArchitecturePractice\\Lab7\\Lab7\\dat\\sample_weight.dat", "r");
- for (int i = 0; i < CHin; i++)
- {
- for (int j = 0; j < R_in; j++)
- {
- for (int x = 0; x < C_in; x++)
- {
- fscanf_s(fin, "%f", &In[i][j][x]);
- }
- }
- }
- for (int i = 0; i < CHout; i++)
- {
- for (int j = 0; j < CHin; j++)
- {
- for (int x = 0; x < K; x++)
- {
- for (int y = 0; y < K; y++)
- {
- fscanf_s(fw, "%f\n", &W[i][j][x][y]);
- }
- }
- }
- }
- for (int i = 0; i < CHout; i++)
- {
- for (int j = 0; j < R_out; j++)
- {
- for (int x = 0; x < C_out; x++)
- {
- Out[i][j][x] = 0;
- }
- }
- }
- int Parameter[] = { CHin, CHout, R_in, C_in, K, S };
- cnn(In[0][0], Out[0][0], W[0][0][0], Parameter);
- for (int i = 0; i < CHout; i++)
- {
- for (int j = 0; j < R_out; j++)
- {
- for (int x = 0; x < C_out; x++)
- {
- fscanf_s(fout, "%f\n", &Gold[i][j][x]);
- }
- }
- }
- float eps1 = 0.01;
- int flag = 0;
- for (int i = 0; i < CHout; i++)
- {
- for (int j = 0; j < R_out; j++)
- {
- for (int x = 0; x < C_out; x++)
- {
- if (fabs(Gold[i][j][x] - Out[i][j][x]) > eps1)
- {
- printf("Not Match!\n");
- printf("Gold is %f\tOut is %f\n", Gold[i][j][x], Out[i][j][x]);
- flag += 1;
- if (flag == 10) {
- fclose(fin);
- fclose(fout);
- fclose(fw);
- return 0;
- }
- }
- }
- }
- }
- printf("Match!\n");
- fclose(fin);
- fclose(fout);
- fclose(fw);
- return 0;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement