Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#include <math.h>
#include <stdio.h>

#include <vector>

#include "mkldnn.hpp"
using namespace mkldnn;
/**
 * Fill dat[0 .. size) with the deterministic test signal 1 + 2 * sin(0.2 * i).
 *
 * Over many samples this signal averages to roughly 1 with a variance of
 * roughly 2, which makes batch-norm statistics easy to sanity-check.
 * A size of zero or less writes nothing.
 *
 * @param dat   destination buffer; must hold at least `size` floats
 * @param size  number of elements to write
 */
void init_data(float *dat, int size) {
    int idx = 0;
    while (idx < size) {
        // The phase is computed in double and then narrowed once for sinf.
        const double phase = 0.2 * idx;
        dat[idx] = 1.f + 2.f * sinf(static_cast<float>(phase));
        ++idx;
    }
}
- void exec_bnrm_fwd() {
- engine eng(engine::cpu, 0);
- const int N = 4, C = 4, H = 27, W = 27;
- const int sz = N * C * H * W;
- float *src = (float *)malloc(sizeof(float) * sz);
- init_data(src, sz); // ideally mean ~= 1.f, var ~= 2.f
- float *dst = (float *)malloc(sizeof(float) * sz);
- float *mean = (float *)malloc(sizeof(float) * C);
- float *var = (float *)malloc(sizeof(float) * C);
- memory::desc data_desc({N, C, H, W}, memory::f32, memory::nchw);
- memory src_mem({data_desc, eng}, src);
- memory dst_mem({data_desc, eng}, dst);
- memory::desc stat_desc({C}, memory::f32, memory::x);
- memory mean_mem({stat_desc, eng}, mean);
- memory var_mem({stat_desc, eng}, var);
- unsigned flags = 0;
- // set flags for different flavors (use | to combine flags)
- // use_global_stats -- do not compute mean and variance in the primitive, user has to provide them
- // use_scale_shift -- in addition to batch norm also scale and shift the result
- batch_normalization_forward::desc bnrm_fwd_d(
- prop_kind::forward_training, // might be forward_inference, backward, backward_data
- data_desc, // data descriptor (i.e. sizes, data type, and layout)
- 0.001f, // eps
- flags);
- batch_normalization_forward::primitive_desc bnrm_fwd_pd(bnrm_fwd_d, eng);
- batch_normalization_forward bnrm_fwd(bnrm_fwd_pd,
- src_mem, dst_mem, mean_mem, var_mem);
- stream(stream::kind::eager).submit({bnrm_fwd}).wait(); // execute bnrm
- for (int c = 0; c < C; ++c) {
- printf("[%d] mean:%f var:%f\n", c, mean[c], var[c]);
- }
- }
- int main() {
- exec_bnrm_fwd();
- return 0;
- }
- // $ ( g++ example_bnrm_fwd.cpp -lmkldnn -lm && ./a.out )
- // [0] mean:0.994635 var:1.997512
- // [1] mean:1.000734 var:2.003489
- // [2] mean:1.005777 var:1.996543
- // [3] mean:1.002504 var:2.002276
Add Comment
Please, Sign In to add comment