Advertisement
Guest User

Untitled

a guest
Jan 10th, 2024
44
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 19.18 KB | None | 0 0
  1. /******************************************************************************
  2. *
  3. * Copyright (C) 2009 - 2014 Xilinx, Inc.  All rights reserved.
  4. *
  5. * Permission is hereby granted, free of charge, to any person obtaining a copy
  6. * of this software and associated documentation files (the "Software"), to deal
  7. * in the Software without restriction, including without limitation the rights
  8. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  9. * copies of the Software, and to permit persons to whom the Software is
  10. * furnished to do so, subject to the following conditions:
  11. *
  12. * The above copyright notice and this permission notice shall be included in
  13. * all copies or substantial portions of the Software.
  14. *
  15. * Use of the Software is limited solely to applications:
  16. * (a) running on a Xilinx device, or
  17. * (b) that interact with a Xilinx device through a bus or interconnect.
  18. *
  19. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  20. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  21. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  22. * XILINX  BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
  23. * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF
  24. * OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  25. * SOFTWARE.
  26. *
  27. * Except as contained in this notice, the name of the Xilinx shall not be used
  28. * in advertising or otherwise to promote the sale, use or other dealings in
  29. * this Software without prior written authorization from Xilinx.
  30. *
  31. ******************************************************************************/
  32.  
  33. /*
  34.  * helloworld.c: simple test application
  35.  *
  36.  * This application configures UART 16550 to baud rate 9600.
  37.  * PS7 UART (Zynq) is not initialized by this application, since
  38.  * bootrom/bsp configures it to baud rate 115200
  39.  *
  40.  * ------------------------------------------------
  41.  * | UART TYPE   BAUD RATE                        |
  42.  * ------------------------------------------------
  43.  *   uartns550   9600
  44.  *   uartlite    Configurable only in HW design
  45.  *   ps7_uart    115200 (configured by bootrom/bsp)
  46.  */
  47.  
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include "platform.h"
#include "xil_printf.h"
#include "xparameters.h"
#include "xuartlite.h"
#include "hw_spec.h"
  55.  
  56. #define VTA_FETCH_ADDR XPAR_VTA_FETCH_0_S_AXI_CONTROL_BASEADDR
  57. #define VTA_LOAD_ADDR XPAR_VTA_LOAD_0_S_AXI_CONTROL_BASEADDR
  58. #define VTA_COMPUTE_ADDR XPAR_VTA_COMPUTE_0_S_AXI_CONTROL_BASEADDR
  59. #define VTA_STORE_ADDR XPAR_VTA_STORE_0_S_AXI_CONTROL_BASEADDR
  60.  
  61. /*! \brief VTA configuration register start value */
  62. #define VTA_START 0x1
  63. /*! \brief VTA configuration register auto-restart value */
  64. #define VTA_AUTORESTART 0x81
  65. /*! \brief VTA configuration register done value */
  66. #define VTA_DONE 0x1
  67.  
  68. typedef enum {false, true} bool;
  69.  
  70. typedef uint32_t uop_T;
  71. typedef int8_t wgt_T;
  72. typedef int8_t inp_T;
  73. typedef int8_t out_T;
  74. typedef int32_t acc_T;
  75.  
  76.  
  77. void *VTAMapRegister(uint32_t addr) {
  78.     return (void *)addr;
  79. }
  80.  
  81. void VTAWriteMappedReg(void *p, uint32_t offset, uint32_t data) {
  82.     *(uint32_t *)(((uint8_t *)p) + offset) = data;
  83. }
  84.  
  85. uint32_t VTAReadMappedReg(void* base_addr, uint32_t offset) {
  86.     return *(uint32_t *)base_addr;
  87. }
  88.  
  89. void vta(uint32_t insn_count, VTAGenericInsn *insns, VTAUop *uops, uint32_t *inputs, uint32_t *weights, uint32_t *biases, uint32_t *outputs) {
  90.  
  91.     // Get VTA handles
  92.     void* vta_fetch_handle = VTAMapRegister(VTA_FETCH_ADDR);
  93.     void* vta_load_handle = VTAMapRegister(VTA_LOAD_ADDR);
  94.     void* vta_compute_handle = VTAMapRegister(VTA_COMPUTE_ADDR);
  95.     void* vta_store_handle = VTAMapRegister(VTA_STORE_ADDR);
  96.  
  97.     VTAWriteMappedReg(vta_fetch_handle, VTA_FETCH_INSN_COUNT_OFFSET, insn_count);
  98.     if (insns) VTAWriteMappedReg(vta_fetch_handle, VTA_FETCH_INSN_ADDR_OFFSET, insns);
  99.     if (inputs) VTAWriteMappedReg(vta_load_handle, VTA_LOAD_INP_ADDR_OFFSET, inputs);
  100.     if (weights) VTAWriteMappedReg(vta_load_handle, VTA_LOAD_WGT_ADDR_OFFSET, weights);
  101.     if (uops) VTAWriteMappedReg(vta_compute_handle, VTA_COMPUTE_UOP_ADDR_OFFSET, uops);
  102.     if (biases) VTAWriteMappedReg(vta_compute_handle, VTA_COMPUTE_BIAS_ADDR_OFFSET, biases);
  103.     if (outputs) VTAWriteMappedReg(vta_store_handle, VTA_STORE_OUT_ADDR_OFFSET, outputs);
  104.  
  105.     // VTA start
  106.     VTAWriteMappedReg(vta_fetch_handle, 0x0, 0x1);
  107.     VTAWriteMappedReg(vta_load_handle, 0x0, 0x81);
  108.     VTAWriteMappedReg(vta_compute_handle, 0x0, 0x81);
  109.     VTAWriteMappedReg(vta_store_handle, 0x0, 0x81);
  110.  
  111.     int flag = 0, t = 0;
  112.     for (t = 0; t < 10000000; ++t) {
  113.         flag = VTAReadMappedReg(vta_compute_handle, VTA_COMPUTE_DONE_RD_OFFSET);
  114.         if (flag & VTA_DONE) break;
  115.     }
  116.  
  117.     if (t == 10000000) {
  118.         xil_printf("\tWARNING: VTA TIMEOUT!!!!\n");
  119.     } else {
  120.         xil_printf("INFO - FPGA Finished!\n");
  121.     }
  122. }
  123.  
  124. const char* getOpcodeString(int opcode, bool use_imm) {
  125.   // Returns string name
  126.   if (opcode == VTA_ALU_OPCODE_MIN) {
  127.     if (use_imm) {
  128.       return "min imm";
  129.     } else {
  130.       return "min";
  131.     }
  132.   } else if (opcode == VTA_ALU_OPCODE_MAX) {
  133.     if (use_imm) {
  134.       return "max imm";
  135.     } else {
  136.       return "max";
  137.     }
  138.   } else if (opcode == VTA_ALU_OPCODE_ADD) {
  139.     if (use_imm) {
  140.       return "add imm";
  141.     } else {
  142.       return "add";
  143.     }
  144.   } else if (opcode == VTA_ALU_OPCODE_SHR) {
  145.     return "shr";
  146.   }
  147.   // else if (opcode == VTA_ALU_OPCODE_MUL) {
  148.   //   return "mul";
  149.   // }
  150.   return "unknown op";
  151. }
  152.  
  153. void assert(uint8_t o) {
  154.     if (!o) {
  155.         xil_printf("ASSERT FAILED!!\r\n");
  156.         while (1);
  157.     }
  158. }
  159.  
  160. void * allocBuffer(size_t num_bytes) {
  161.     return malloc(num_bytes);
  162. }
  163.  
  164. VTAGenericInsn get1DLoadStoreInsn(int opcode, int type, int sram_offset, int dram_offset, int size,
  165.     int pop_prev_dep, int pop_next_dep, int push_prev_dep, int push_next_dep) {
  166.     // Converter
  167.     union VTAInsn converter;
  168.     // Memory instruction initialization
  169.     VTAMemInsn insn = {};
  170.     insn.opcode = opcode;
  171.     insn.pop_prev_dep = pop_prev_dep;
  172.     insn.pop_next_dep = pop_next_dep;
  173.     insn.push_prev_dep = push_prev_dep;
  174.     insn.push_next_dep = push_next_dep;
  175.     insn.memory_type = type;
  176.     insn.sram_base = sram_offset;
  177.     insn.dram_base = dram_offset;
  178.     insn.y_size = 1;
  179.     insn.x_size = size;
  180.     insn.x_stride = size;
  181.     insn.y_pad_0 = 0;
  182.     insn.y_pad_1 = 0;
  183.     insn.x_pad_0 = 0;
  184.     insn.x_pad_1 = 0;
  185.     converter.mem = insn;
  186.     return converter.generic;
  187. }
  188.  
  189. VTAGenericInsn get2DLoadStoreInsn(int opcode, int type, int sram_offset, int dram_offset,
  190.     int y_size, int x_size, int x_stride, int y_pad, int x_pad, int pop_prev_dep, int pop_next_dep,
  191.     int push_prev_dep, int push_next_dep) {
  192.   // Converter
  193.   union VTAInsn converter;
  194.   // Memory instruction initialization
  195.   VTAMemInsn insn = {};
  196.   insn.opcode = opcode;
  197.   insn.pop_prev_dep = pop_prev_dep;
  198.   insn.pop_next_dep = pop_next_dep;
  199.   insn.push_prev_dep = push_prev_dep;
  200.   insn.push_next_dep = push_next_dep;
  201.   insn.memory_type = type;
  202.   insn.sram_base = sram_offset;
  203.   insn.dram_base = dram_offset;
  204.   insn.y_size = y_size;
  205.   insn.x_size = x_size;
  206.   insn.x_stride = x_stride;
  207.   insn.y_pad_0 = y_pad;
  208.   insn.y_pad_1 = y_pad;
  209.   insn.x_pad_0 = x_pad;
  210.   insn.x_pad_1 = x_pad;
  211.   converter.mem = insn;
  212.   return converter.generic;
  213. }
  214.  
  215. VTAGenericInsn getALUInsn(int opcode, int vector_size, bool use_imm, int imm, bool uop_compression,
  216.     int pop_prev_dep, int pop_next_dep, int push_prev_dep, int push_next_dep) {
  217.   // Converter
  218.   union VTAInsn converter;
  219.   // Memory instruction initialization
  220.   VTAAluInsn insn = {};
  221.   insn.opcode = VTA_OPCODE_ALU;
  222.   insn.pop_prev_dep = pop_prev_dep;
  223.   insn.pop_next_dep = pop_next_dep;
  224.   insn.push_prev_dep = push_prev_dep;
  225.   insn.push_next_dep = push_next_dep;
  226.   insn.reset_reg = false;
  227.   if (!uop_compression) {
  228.     insn.uop_bgn = 0;
  229.     insn.uop_end = vector_size;
  230.     insn.iter_out = 1;
  231.     insn.iter_in = 1;
  232.     insn.dst_factor_out = 0;
  233.     insn.src_factor_out = 0;
  234.     insn.dst_factor_in = 0;
  235.     insn.src_factor_in = 0;
  236.     insn.alu_opcode = opcode;
  237.     insn.use_imm = use_imm;
  238.     insn.imm = imm;
  239.   } else {
  240.     insn.uop_bgn = 0;
  241.     insn.uop_end = 1;
  242.     insn.iter_out = 1;
  243.     insn.iter_in = vector_size;
  244.     insn.dst_factor_out = 0;
  245.     insn.src_factor_out = 0;
  246.     insn.dst_factor_in = 1;
  247.     insn.src_factor_in = 1;
  248.     insn.alu_opcode = opcode;
  249.     insn.use_imm = use_imm;
  250.     insn.imm = imm;
  251.   }
  252.   converter.alu = insn;
  253.   return converter.generic;
  254. }
  255.  
  256. VTAGenericInsn getFinishInsn(bool pop_prev, bool pop_next) {
  257.   // Converter
  258.   union VTAInsn converter;
  259.   // GEMM instruction initialization
  260.   VTAGemInsn insn;
  261.   insn.opcode = VTA_OPCODE_FINISH;
  262.   insn.pop_prev_dep = pop_prev;
  263.   insn.pop_next_dep = pop_next;
  264.   insn.push_prev_dep = 0;
  265.   insn.push_next_dep = 0;
  266.   insn.reset_reg = false;
  267.   insn.uop_bgn = 0;
  268.   insn.uop_end = 0;
  269.   insn.iter_out = 0;
  270.   insn.iter_in = 0;
  271.   insn.dst_factor_out = 0;
  272.   insn.src_factor_out = 0;
  273.   insn.wgt_factor_out = 0;
  274.   insn.dst_factor_in = 0;
  275.   insn.src_factor_in = 0;
  276.   insn.wgt_factor_in = 0;
  277.   converter.gemm = insn;
  278.   return converter.generic;
  279. }
  280.  
  281. VTAUop * getMapALUUops(int vector_size, bool uop_compression) {
  282.     // Derive the total uop size
  283.     int uop_size = (uop_compression) ? 1 : vector_size;
  284.  
  285.     // Allocate buffer
  286.     VTAUop *uop_buf = (VTAUop *)(malloc(sizeof(VTAUop) * uop_size));
  287.  
  288.     if (!uop_compression) {
  289.     for (int i = 0; i < vector_size; i++) {
  290.       uop_buf[i].dst_idx = i;
  291.       uop_buf[i].src_idx = vector_size + i;
  292.     }
  293.     } else {
  294.     uop_buf[0].dst_idx = 0;
  295.     uop_buf[0].src_idx = vector_size;
  296.     }
  297.  
  298.     return uop_buf;
  299. }
  300.  
  301. uint32_t globalSeed;
  302.  
  303. acc_T** alloc2dArray_accT(int rows, int cols) {
  304.     acc_T**array = (acc_T **)(malloc(sizeof(acc_T *) * rows));
  305.       for (int i = 0; i < rows; i++) {
  306.         array[i] = (acc_T *)(malloc(sizeof(acc_T) * cols));
  307.       }
  308.       return array;
  309. }
  310.  
  311. out_T** alloc2dArray_outT(int rows, int cols) {
  312.     out_T**array = (out_T **)(malloc(sizeof(out_T *) * rows));
  313.       for (int i = 0; i < rows; i++) {
  314.         array[i] = (out_T *)(malloc(sizeof(out_T) * cols));
  315.       }
  316.       return array;
  317. }
  318.  
  319. void packBuffer_uint32_32_accT_VTAACCWIDTH(uint32_t *dst, acc_T **src, int y_size, int x_size, int y_block, int x_block) {
  320.   assert((VTA_ACC_WIDTH * x_block * y_block) % 32  == 0);
  321.   assert(32 <= 64);
  322.   int buffer_idx = 0;
  323.   int ratio = 32 / VTA_ACC_WIDTH;
  324.   long long int mask = (1ULL << VTA_ACC_WIDTH) - 1;
  325.   uint32_t tmp = 0;
  326.   for (int i = 0; i < y_size / y_block; i++) {
  327.     for (int j = 0; j < x_size / x_block; j++) {
  328.       for (int k = 0; k < y_block; k++) {
  329.         for (int l = 0; l < x_block; l++) {
  330.           int block_idx = l + k * x_block;
  331.           tmp |= (src[i * y_block + k][j * x_block + l] & mask) << ((block_idx % ratio) * VTA_ACC_WIDTH);
  332.           // When tmp is packed, write to destination array
  333.           if (block_idx % ratio == ratio - 1) {
  334.             dst[buffer_idx++] = tmp;
  335.             tmp = 0;
  336.           }
  337.         }
  338.       }
  339.     }
  340.   }
  341. }
  342.  
  343. void unpackBuffer(out_T **dst, uint32_t *src, int y_size, int x_size, int y_block, int x_block) {
  344.   assert((VTA_OUT_WIDTH * x_block * y_block) % 32 == 0);
  345.   int buffer_idx = 0;
  346.   long long int mask = (1ULL << VTA_OUT_WIDTH) - 1;
  347.   int ratio = 32 / VTA_OUT_WIDTH;
  348.   for (int i = 0; i < y_size / y_block; i++) {
  349.     for (int j = 0; j < x_size / x_block; j++) {
  350.       for (int k = 0; k < y_block; k++) {
  351.         for (int l = 0; l < x_block; l++) {
  352.           int block_idx = l + k * x_block;
  353.           dst[i * y_block + k][j * x_block + l] = (src[buffer_idx] >> ((block_idx % ratio) * VTA_OUT_WIDTH)) & mask;
  354.           if (block_idx % ratio == ratio - 1) {
  355.             buffer_idx++;
  356.           }
  357.         }
  358.       }
  359.     }
  360.   }
  361. }
  362.  
  363. int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_compression) {
  364.     // Some assertions
  365.     assert(batch % VTA_BATCH == 0);
  366.     assert(vector_size % VTA_BLOCK_OUT == 0);
  367.     printf("=====================================================================================\n");
  368.     printf("INFO - ALU test of %s: batch=%d, vector_size=%d, uop_compression=%d\n", getOpcodeString(opcode, use_imm), batch, vector_size, uop_compression);
  369.  
  370.     // Instruction count
  371.     int ins_size = 3 * batch / VTA_BATCH + 2;
  372.     // Micro op count
  373.     int uop_size = uop_compression ? 1 : vector_size / VTA_BLOCK_OUT;
  374.     // Input/output elements in each transfer
  375.     int tx_size = vector_size / VTA_BLOCK_OUT;
  376.     // Number of input sets to be generated
  377.     int input_sets = (use_imm) ? 1 : 2;
  378.     // Make sure we don't exceed buffer bounds
  379.     assert(uop_size <= VTA_UOP_BUFF_DEPTH);
  380.     assert(tx_size * input_sets <= VTA_ACC_BUFF_DEPTH);
  381.  
  382.     // Immediate values
  383.     acc_T *immediate = (acc_T *)(malloc(sizeof(acc_T) * batch / VTA_BATCH));
  384.     for (int b = 0; b < batch / VTA_BATCH; b++) {
  385.         if (opcode == VTA_ALU_OPCODE_MIN) {
  386.             immediate[b] = (acc_T)(rand_r(&globalSeed) % (1LL << (VTA_ALUOP_IMM_BIT_WIDTH - 1)) - (1LL << (VTA_ALUOP_IMM_BIT_WIDTH - 2)));
  387.         } else if (opcode == VTA_ALU_OPCODE_MAX) {
  388.             immediate[b] = (acc_T)(rand_r(&globalSeed) % (1LL << (VTA_ALUOP_IMM_BIT_WIDTH - 1)) - (1LL << (VTA_ALUOP_IMM_BIT_WIDTH - 2)));
  389.         } else if (opcode == VTA_ALU_OPCODE_ADD) {
  390.             immediate[b] = (acc_T)(rand_r(&globalSeed) % (1LL << (VTA_ALUOP_IMM_BIT_WIDTH - 1)) - (1LL << (VTA_ALUOP_IMM_BIT_WIDTH - 2)));
  391.         } else if (opcode == VTA_ALU_OPCODE_SHR) {
  392.             immediate[b] = (acc_T)(rand_r(&globalSeed) % (1LL << (VTA_SHR_ARG_BIT_WIDTH - 1)) - (1LL << (VTA_SHR_ARG_BIT_WIDTH - 2)));
  393.         }
  394.     }
  395.  
  396.     // Initialize instructions
  397.     VTAGenericInsn *insn_buf = (VTAGenericInsn *)(allocBuffer(sizeof(VTAGenericInsn) * ins_size));
  398.     int insn_idx = 0;
  399.     insn_buf[insn_idx++] = get1DLoadStoreInsn(VTA_OPCODE_LOAD, VTA_MEM_ID_UOP, 0, 0, uop_size, 0, 0, 0, 0);
  400.  
  401.     for (int b = 0; b < batch; b += VTA_BATCH) {
  402.         insn_buf[insn_idx++] = get2DLoadStoreInsn(
  403.                 VTA_OPCODE_LOAD,                   // opcode
  404.                 VTA_MEM_ID_ACC,                    // vector size
  405.                 0,                                 // sram offset
  406.                 b / VTA_BATCH * tx_size * input_sets,  // dram offset
  407.                 1,                                 // y size
  408.                 tx_size * input_sets,              // x size
  409.                 tx_size * input_sets,              // x stride
  410.                 0,                                 // y pad
  411.                 0,                                 // x pad
  412.                 0,                                 // pop prev dep
  413.                 b > 0,                             // pop next dep
  414.                 0,                                 // push prev dep
  415.                 0);                                // push next dep
  416.  
  417.         insn_buf[insn_idx++] = getALUInsn(
  418.                 opcode,                            // opcode
  419.                 tx_size,                           // vector size
  420.                 use_imm,                           // use imm
  421.                 immediate[b / VTA_BATCH],          // imm
  422.                 uop_compression,                   // uop compression
  423.                 0,                                 // pop prev dep
  424.                 0,                                 // pop next dep
  425.                 0,                                 // push prev dep
  426.                 1);                                // push next dep
  427.  
  428.         insn_buf[insn_idx++] = get2DLoadStoreInsn(
  429.                 VTA_OPCODE_STORE,                  // opcode
  430.                 VTA_MEM_ID_OUT,                    // vector size
  431.                 0,                                 // sram offset
  432.                 b / VTA_BATCH * tx_size,           // dram offset
  433.                 1,                                 // y size
  434.                 tx_size,                           // x size
  435.                 tx_size,                           // x stride
  436.                 0,                                 // y pad
  437.                 0,                                 // x pad
  438.                 1,                                 // pop prev dep
  439.                 0,                                 // pop next dep
  440.                 1,                                 // push prev dep
  441.                 0);                                // push next dep
  442.     }
  443.     // Finish
  444.     insn_buf[insn_idx++] = getFinishInsn(0, 1);
  445.     // Prepare the uop buffer
  446.     VTAUop * uop_buf = getMapALUUops(tx_size, uop_compression);
  447.  
  448.     // Initialize the input/output data
  449.     acc_T **inputs = alloc2dArray_accT(batch, vector_size * input_sets);
  450.     for (int i = 0; i < batch; i++) {
  451.         for (int j = 0; j < vector_size * input_sets; j++) {
  452.             if (opcode == VTA_ALU_OPCODE_MIN) {
  453.                 inputs[i][j] = (acc_T)(rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH - 1)) - (1LL << (VTA_INP_WIDTH - 2)));
  454.             } else if (opcode == VTA_ALU_OPCODE_MAX) {
  455.                 inputs[i][j] = (acc_T)(rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH - 1)) - (1LL << (VTA_INP_WIDTH - 2)));
  456.             } else if (opcode == VTA_ALU_OPCODE_ADD) {
  457.                 inputs[i][j] = (acc_T)(rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH - 2)) - (1LL << (VTA_INP_WIDTH - 3)));
  458.             } else if (opcode == VTA_ALU_OPCODE_SHR) {
  459.                 inputs[i][j] = (acc_T)(rand_r(&globalSeed) % (1LL << (VTA_SHR_ARG_BIT_WIDTH - 1)) - (1LL << (VTA_SHR_ARG_BIT_WIDTH - 2)));
  460.             }
  461.         }
  462.     }
  463.  
  464.     // Compute reference output
  465.     out_T **outputs_ref = alloc2dArray_outT(batch, vector_size);
  466.     for (int i = 0; i < batch; i++) {
  467.         for (int j = 0; j < vector_size; j++) {
  468.             acc_T out_val = 0;
  469.             acc_T imm_val = immediate[i / VTA_BATCH];
  470.             acc_T src_val = inputs[i][j + vector_size];
  471.             if (opcode == VTA_ALU_OPCODE_MIN) {
  472.                 if (!use_imm) {
  473.                     out_val = inputs[i][j] < src_val ? inputs[i][j] : src_val;
  474.                 } else {
  475.                     out_val = inputs[i][j] < imm_val ? inputs[i][j] : imm_val;
  476.                 }
  477.             } else if (opcode == VTA_ALU_OPCODE_MAX) {
  478.                 if (!use_imm) {
  479.                     out_val = inputs[i][j] > src_val ? inputs[i][j] : src_val;
  480.                 } else {
  481.                     out_val = inputs[i][j] > imm_val ? inputs[i][j] : imm_val;
  482.                 }
  483.             } else if (opcode == VTA_ALU_OPCODE_ADD) {
  484.                 if (!use_imm) {
  485.                     out_val = inputs[i][j] + src_val;
  486.                 } else {
  487.                     out_val = inputs[i][j] + imm_val;
  488.                 }
  489.             } else if (opcode == VTA_ALU_OPCODE_SHR) {
  490.                 if (!use_imm) {
  491.                     if (src_val >= 0) {
  492.                         out_val = inputs[i][j] >> src_val;
  493.                     } else {
  494.                         out_val = inputs[i][j] << (0 - src_val);
  495.                     }
  496.                 } else {
  497.                     if (imm_val >= 0) {
  498.                         out_val = inputs[i][j] >> imm_val;
  499.                     } else {
  500.                         out_val = inputs[i][j] << (0 - imm_val);
  501.                     }
  502.                 }
  503.             }
  504.             outputs_ref[i][j] = (out_T) out_val;
  505.         }
  506.     }
  507.  
  508.     // Pack input buffer
  509.     uint32_t *bias_buf = (uint32_t *)(allocBuffer(VTA_ACC_ELEM_BYTES * batch * tx_size * input_sets));
  510.     packBuffer_uint32_32_accT_VTAACCWIDTH(bias_buf, inputs, batch, vector_size * input_sets, VTA_BATCH, VTA_BLOCK_OUT);
  511.  
  512.     // Prepare output buffer
  513.     uint32_t *output_buf = (uint32_t *)(allocBuffer(VTA_OUT_ELEM_BYTES * batch * tx_size * input_sets));
  514.  
  515.     // Invoke the VTA
  516.     vta(ins_size, insn_buf, uop_buf, NULL, NULL, bias_buf, output_buf);
  517.  
  518. // Unpack output buffer
  519.     out_T **outputs = alloc2dArray_outT(batch, vector_size);
  520.     unpackBuffer(outputs, output_buf, batch, vector_size, VTA_BATCH, VTA_BLOCK_OUT);
  521.  
  522.     // Correctness checks
  523.     int err = 0;
  524.     for (int i = 0; i < batch; i++) {
  525.         for (int j = 0; j < vector_size; j++) {
  526.             if (outputs_ref[i][j] != outputs[i][j]) {
  527.                 err++;
  528.                 printf("DEBUG - %d, %d: expected 0x%x but got 0x%x\n", i, j, (int)(outputs_ref[i][j]), (int)(outputs[i][j]));
  529.             }
  530.         }
  531.     }
  532.  
  533.     // Free all allocated arrays
  534.     free(immediate);
  535.  
  536.     if (err == 0) {
  537.         printf("INFO - ALU test successful!\n");
  538.         return 0;
  539.     } else {
  540.         printf("INFO - ALU test failed, got %d errors!\n", err);
  541.         return -1;
  542.     }
  543. }
  544.  
  545. int main()
  546. {
  547.  
  548.     XUartLite UartLite;
  549.  
  550.     init_platform();
  551.     print("Hello World\n\r");
  552.  
  553.     XUartLite_Initialize(&UartLite, XPAR_AXI_UARTLITE_0_DEVICE_ID);
  554.  
  555.     alu_test(VTA_ALU_OPCODE_MAX, true, 16, 128, true);
  556.  
  557.  
  558.  
  559.     print("Successfully ran Hello World application");
  560.     cleanup_platform();
  561.     return 0;
  562. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement