Guest User

Untitled

a guest
Dec 18th, 2018
79
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.04 KB | None | 0 0
  #include <stdio.h>
  #include <math.h>

  #include <exception>
  #include <vector>

  #include "mkldnn.hpp"

  using namespace mkldnn;
  7.  
  // Fills dat[0..size) with a smooth test signal: dat[k] = 1 + 2*sin(0.2*k).
  // Over many samples this has mean ~1 and variance ~2, which makes the
  // statistics printed by the batch-norm example easy to sanity-check.
  // Nothing is written when size <= 0.
  void init_data(float *dat, int size) {
      int k = 0;
      while (k < size) {
          // The phase is computed in double and narrowed to float by sinf.
          const double phase = 0.2 * k;
          dat[k] = 1.f + 2.f * sinf(phase);
          ++k;
      }
  }
  12.  
  13. void exec_bnrm_fwd() {
  14. engine eng(engine::cpu, 0);
  15.  
  16. const int N = 4, C = 4, H = 27, W = 27;
  17. const int sz = N * C * H * W;
  18.  
  19. float *src = (float *)malloc(sizeof(float) * sz);
  20. init_data(src, sz); // ideally mean ~= 1.f, var ~= 2.f
  21.  
  22. float *dst = (float *)malloc(sizeof(float) * sz);
  23. float *mean = (float *)malloc(sizeof(float) * C);
  24. float *var = (float *)malloc(sizeof(float) * C);
  25.  
  26. memory::desc data_desc({N, C, H, W}, memory::f32, memory::nchw);
  27. memory src_mem({data_desc, eng}, src);
  28. memory dst_mem({data_desc, eng}, dst);
  29.  
  30. memory::desc stat_desc({C}, memory::f32, memory::x);
  31. memory mean_mem({stat_desc, eng}, mean);
  32. memory var_mem({stat_desc, eng}, var);
  33.  
  34. unsigned flags = 0;
  35. // set flags for different flavors (use | to combine flags)
  36. // use_global_stats -- do not compute mean and variance in the primitive, user has to provide them
  37. // use_scale_shift -- in addition to batch norm also scale and shift the result
  38.  
  39. batch_normalization_forward::desc bnrm_fwd_d(
  40. prop_kind::forward_training, // might be forward_inference, backward, backward_data
  41. data_desc, // data descriptor (i.e. sizes, data type, and layout)
  42. 0.001f, // eps
  43. flags);
  44. batch_normalization_forward::primitive_desc bnrm_fwd_pd(bnrm_fwd_d, eng);
  45.  
  46. batch_normalization_forward bnrm_fwd(bnrm_fwd_pd,
  47. src_mem, dst_mem, mean_mem, var_mem);
  48. stream(stream::kind::eager).submit({bnrm_fwd}).wait(); // execute bnrm
  49.  
  50. for (int c = 0; c < C; ++c) {
  51. printf("[%d] mean:%f var:%f\n", c, mean[c], var[c]);
  52. }
  53. }
  54.  
  55. int main() {
  56. exec_bnrm_fwd();
  57. return 0;
  58. }
  59.  
  60. // $ ( g++ example_bnrm_fwd.cpp -lmkldnn -lm && ./a.out )
  61. // [0] mean:0.994635 var:1.997512
  62. // [1] mean:1.000734 var:2.003489
  63. // [2] mean:1.005777 var:1.996543
  64. // [3] mean:1.002504 var:2.002276
Add Comment
Please sign in to add a comment.