Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <iostream>
- #include <vector>
- #include <memory>
- #include "mkldnn.hpp"
- using namespace mkldnn;
- using mem_d = memory::desc;
- using mem_pd = memory::primitive_desc;
- void build_net(std::vector<primitive>& net, engine &cpu) {
- float* biases = new float[1];
- float* inputs = new float[1 * 3 * 10 * 10];
- float* weights = new float[1 * 3 * 3 * 3];
- float* results = new float[1 * 1 * 10 * 10];
- memory::dims conv_input_dims = { 1, 3, 10, 10};
- memory::dims conv_weight_dims = {1, 3, 3, 3};
- memory::dims conv_bias_dims = { 1 };
- memory::dims conv_output_dims = {1, 1, 10, 10};
- memory::dims conv_stride = { 1, 1 };
- auto conv_padding = { 1, 1 };
- // Create memory descriptors for the user data (using the actual data layout).
- auto user_input_md = mem_d(
- { conv_input_dims }, memory::data_type::f32, memory::format::nchw);
- auto user_weight_md = mem_d(
- { conv_weight_dims }, memory::data_type::f32, memory::format::oihw);
- auto user_bias_md = mem_d(
- { conv_bias_dims }, memory::data_type::f32, memory::format::x);
- // Create primitive memory descriptors for user data.
- auto user_input_memory_descriptor = mem_pd(user_input_md, cpu);
- auto user_weight_memory_descriptor = mem_pd(user_weight_md, cpu);
- auto user_bias_memory_descriptor = mem_pd(user_bias_md, cpu);
- // Create memory primitives for user input data (input, weights, biases).
- std::shared_ptr<memory> user_input_memory(
- new memory(user_input_memory_descriptor, inputs));
- std::shared_ptr<memory> user_weight_memory(
- new memory(user_weight_memory_descriptor, weights));
- std::shared_ptr<memory> user_bias_memory(
- new memory(user_bias_memory_descriptor, biases));
- // Create memory descriptors for the convolution primitive.
- auto conv_input_md = mem_d(
- { conv_input_dims }, memory::data_type::f32, memory::format::any);
- auto conv_weight_md = mem_d(
- { conv_weight_dims }, memory::data_type::f32, memory::format::any);
- auto conv_bias_md = mem_d(
- { conv_bias_dims }, memory::data_type::f32, memory::format::any);
- auto conv_output_md = mem_d(
- { conv_output_dims }, memory::data_type::f32, memory::format::any);
- // Create the convolution primitive.
- auto conv_desc = convolution_forward::desc(
- prop_kind::forward, algorithm::convolution_direct, conv_input_md,
- conv_weight_md, conv_bias_md, conv_output_md, conv_stride,
- conv_padding, conv_padding, mkldnn::padding_kind::zero);
- auto conv_pd = convolution_forward::primitive_desc(conv_desc, cpu);
- // Check if a data layout transform is required.
- auto conv_input_memory = user_input_memory;
- if (mem_pd(conv_pd.src_primitive_desc()) !=
- user_input_memory_descriptor) {
- conv_input_memory = std::shared_ptr<memory>(
- new memory(conv_pd.src_primitive_desc()));
- /*
- auto conv_reorder_input = std::shared_ptr<reorder>(
- new reorder(*user_input_memory, *conv_input_memory));
- auto conv_reorder_input = reorder(*user_input_memory, *conv_input_memory);
- net.push_back(conv_reorder_input);
- */
- net.emplace_back(reorder(*user_input_memory, *conv_input_memory));
- }
- auto conv_weight_memory = user_weight_memory;
- if (mem_pd(conv_pd.weights_primitive_desc()) !=
- user_weight_memory_descriptor) {
- conv_weight_memory = std::shared_ptr<memory>(
- new memory(conv_pd.src_primitive_desc()));
- /*
- auto conv_reorder_weight = std::shared_ptr<reorder>(
- new reorder(*user_weight_memory, *conv_weight_memory));
- auto conv_reorder_weight =
- reorder(*user_weight_memory, *conv_weight_memory);
- net.push_back(conv_reorder_weight);
- */
- net.emplace_back(reorder(*user_weight_memory, *conv_weight_memory));
- }
- // Create memory primitives for the output.
- auto conv_output_memory = std::shared_ptr<memory>(
- new memory(conv_pd.dst_primitive_desc(), results));
- // Finally, create the convolution primitive.
- /*
- auto conv = std::shared_ptr<convolution_forward>(new convolution_forward(
- conv_pd, *conv_input_memory, *conv_weight_memory, *user_bias_memory,
- *conv_output_memory));
- auto conv = convolution_forward(conv_pd, *conv_input_memory,
- *conv_weight_memory, *user_bias_memory,
- *conv_output_memory);
- net.push_back(conv);
- */
- net.emplace_back(convolution_forward(conv_pd, *conv_input_memory,
- *conv_weight_memory, *user_bias_memory,
- *conv_output_memory));
- }
- int main() {
- std::vector<primitive> net;
- engine cpu(engine::cpu, 0);
- build_net(net, cpu);
- stream(stream::kind::eager).submit(net).wait();
- return 0;
- }
Add Comment
Please, Sign In to add comment