#include <iostream>
#include <vector>
#include <memory>
#include "mkldnn.hpp"

using namespace mkldnn;
using mem_d = memory::desc;
using mem_pd = memory::primitive_desc;

void build_net(std::vector<primitive>& net, engine &cpu) {
    // Raw buffers for the user data. They are never freed in this example and
    // must stay alive until the stream that executes the net has finished.
    float* biases = new float[1];
    float* inputs = new float[1 * 3 * 10 * 10];
    float* weights = new float[1 * 3 * 3 * 3];
    float* results = new float[1 * 1 * 10 * 10];

    memory::dims conv_input_dims = { 1, 3, 10, 10 };
    memory::dims conv_weight_dims = { 1, 3, 3, 3 };
    memory::dims conv_bias_dims = { 1 };
    memory::dims conv_output_dims = { 1, 1, 10, 10 };
    memory::dims conv_stride = { 1, 1 };
    memory::dims conv_padding = { 1, 1 };

    // Create memory descriptors for the user data (using the actual data layout).
    auto user_input_md = mem_d(
        { conv_input_dims }, memory::data_type::f32, memory::format::nchw);
    auto user_weight_md = mem_d(
        { conv_weight_dims }, memory::data_type::f32, memory::format::oihw);
    auto user_bias_md = mem_d(
        { conv_bias_dims }, memory::data_type::f32, memory::format::x);

    // Create primitive memory descriptors for the user data.
    auto user_input_memory_descriptor = mem_pd(user_input_md, cpu);
    auto user_weight_memory_descriptor = mem_pd(user_weight_md, cpu);
    auto user_bias_memory_descriptor = mem_pd(user_bias_md, cpu);

    // Create memory primitives for the user data (input, weights, biases).
    std::shared_ptr<memory> user_input_memory(
        new memory(user_input_memory_descriptor, inputs));
    std::shared_ptr<memory> user_weight_memory(
        new memory(user_weight_memory_descriptor, weights));
    std::shared_ptr<memory> user_bias_memory(
        new memory(user_bias_memory_descriptor, biases));

    // Create memory descriptors for the convolution primitive
    // (format::any lets the implementation choose its preferred layouts).
    auto conv_input_md = mem_d(
        { conv_input_dims }, memory::data_type::f32, memory::format::any);
    auto conv_weight_md = mem_d(
        { conv_weight_dims }, memory::data_type::f32, memory::format::any);
    auto conv_bias_md = mem_d(
        { conv_bias_dims }, memory::data_type::f32, memory::format::any);
    auto conv_output_md = mem_d(
        { conv_output_dims }, memory::data_type::f32, memory::format::any);

    // Create the convolution descriptor and primitive descriptor.
    auto conv_desc = convolution_forward::desc(
        prop_kind::forward, algorithm::convolution_direct, conv_input_md,
        conv_weight_md, conv_bias_md, conv_output_md, conv_stride,
        conv_padding, conv_padding, mkldnn::padding_kind::zero);
    auto conv_pd = convolution_forward::primitive_desc(conv_desc, cpu);

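    // Because the convolution descriptors above used format::any, conv_pd now
    // describes the memory layouts the selected implementation prefers. The
    // blocks below compare those layouts with the user layouts and insert
    // reorder primitives into the net where they differ.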
    // Check whether the input needs a data layout transform (reorder).
    auto conv_input_memory = user_input_memory;
    if (conv_pd.src_primitive_desc() != user_input_memory_descriptor) {
        conv_input_memory = std::shared_ptr<memory>(
            new memory(conv_pd.src_primitive_desc()));
        net.emplace_back(reorder(*user_input_memory, *conv_input_memory));
    }

    // Same check for the weights.
    auto conv_weight_memory = user_weight_memory;
    if (conv_pd.weights_primitive_desc() != user_weight_memory_descriptor) {
        conv_weight_memory = std::shared_ptr<memory>(
            new memory(conv_pd.weights_primitive_desc()));
        net.emplace_back(reorder(*user_weight_memory, *conv_weight_memory));
    }

    // Create a memory primitive for the output, backed by the results buffer.
    auto conv_output_memory = std::shared_ptr<memory>(
        new memory(conv_pd.dst_primitive_desc(), results));

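    // Note: dst_primitive_desc() may describe a blocked layout rather than
    // plain nchw, in which case `results` ends up holding data in that
    // layout. Reading it back as nchw would need one more reorder into a
    // user-format memory (analogous to the input reorder above); this example
    // leaves the output in whatever layout the convolution chose.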
    // Finally, create the convolution primitive and append it to the net.
    net.emplace_back(convolution_forward(conv_pd, *conv_input_memory,
                                         *conv_weight_memory, *user_bias_memory,
                                         *conv_output_memory));
}

int main() {
    std::vector<primitive> net;
    engine cpu(engine::cpu, 0);
    build_net(net, cpu);
    // An eager stream executes the submitted primitives in order; wait()
    // blocks until they have all finished.
    stream(stream::kind::eager).submit(net).wait();
    return 0;
}
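
// Build sketch (assumptions: an MKL-DNN v0.x install with mkldnn.hpp and
// libmkldnn visible to the compiler; the source file name is a placeholder):
//   g++ -std=c++11 conv_example.cpp -lmkldnn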