Advertisement
Guest User

Untitled

a guest
Nov 5th, 2017
396
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 9.56 KB | None | 0 0
#include <algorithm>
#include <cctype>
#include <fstream>
#include <iostream>
#include <utility>
#include <vector>

#include <tensorflow/cc/ops/const_op.h>
#include <tensorflow/cc/ops/image_ops.h>
#include <tensorflow/cc/ops/standard_ops.h>
#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/framework/tensor.h>
#include <tensorflow/core/graph/default_device.h>
#include <tensorflow/core/graph/graph_def_builder.h>
#include <tensorflow/core/lib/core/errors.h>
#include <tensorflow/core/lib/core/stringpiece.h>
#include <tensorflow/core/lib/core/threadpool.h>
#include <tensorflow/core/lib/io/path.h>
#include <tensorflow/core/lib/strings/stringprintf.h>
#include <tensorflow/core/platform/env.h>
#include <tensorflow/core/platform/init_main.h>
#include <tensorflow/core/platform/logging.h>
#include <tensorflow/core/platform/types.h>
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/util/command_line_flags.h>
  24.  
  25. // These are all common classes it's handy to reference with no namespace.
  26. using tensorflow::Flag;
  27. using tensorflow::Tensor;
  28. using tensorflow::Status;
  29. using tensorflow::string;
  30. using tensorflow::int32;
  31.  
  32. using namespace std;
  33.  
  34. namespace{
  35.  
  36. static Status ReadEntireFile(tensorflow::Env* env, const string& filename,
  37.                              Tensor* output) {
  38.     tensorflow::uint64 file_size = 0;
  39.     TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
  40.  
  41.     string contents;
  42.     contents.resize(file_size);
  43.  
  44.     std::unique_ptr<tensorflow::RandomAccessFile> file;
  45.     TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
  46.  
  47.     tensorflow::StringPiece data;
  48.     TF_RETURN_IF_ERROR(file->Read(0, file_size, &data, &(contents)[0]));
  49.     if (data.size() != file_size) {
  50.         return tensorflow::errors::DataLoss("Truncated read of '", filename,
  51.                                             "' expected ", file_size, " got ",
  52.                                             data.size());
  53.     }
  54.     output->scalar<string>()() = data.ToString();
  55.     return Status::OK();
  56. }
  57.  
  58. // Given an image file name, read in the data, try to decode it as an image,
  59. // resize it to the requested size, and then scale the values as desired.
  60. Status ReadTensorFromImageFile(const string& file_name,
  61.                                std::vector<Tensor>* out_tensors) {
  62.     auto root = tensorflow::Scope::NewRootScope();
  63.     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
  64.  
  65.     // read file_name into a tensor named input
  66.     Tensor input(tensorflow::DT_STRING, tensorflow::TensorShape());
  67.     TF_RETURN_IF_ERROR(
  68.                 ReadEntireFile(tensorflow::Env::Default(), file_name, &input));
  69.  
  70.     // use a placeholder to read input data
  71.     auto file_reader =
  72.             Placeholder(root.WithOpName("input"), tensorflow::DataType::DT_STRING);
  73.  
  74.     std::vector<std::pair<string, tensorflow::Tensor>> inputs = {
  75.         {"input", input},
  76.     };
  77.  
  78.     // Now try to figure out what kind of file it is and decode it.
  79.     const int wanted_channels = 3;
  80.     tensorflow::Output image_reader;
  81.     if (tensorflow::StringPiece(file_name).ends_with(".png")) {
  82.         image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
  83.                                  DecodePng::Channels(wanted_channels));
  84.     } else if (tensorflow::StringPiece(file_name).ends_with(".gif")) {
  85.         // gif decoder returns 4-D tensor, remove the first dim
  86.         image_reader =
  87.                 Squeeze(root.WithOpName("squeeze_first_dim"),
  88.                         DecodeGif(root.WithOpName("gif_reader"), file_reader));
  89.     } else {
  90.         // Assume if it's neither a PNG nor a GIF then it must be a JPEG.
  91.         image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
  92.                                   DecodeJpeg::Channels(wanted_channels));
  93.     }
  94.     // Now cast the image data to float so we can do normal math on it.
  95.     // auto float_caster =
  96.     //     Cast(root.WithOpName("float_caster"), image_reader, tensorflow::DT_FLOAT);
  97.  
  98.     auto uint8_caster =  Cast(root.WithOpName("uint8_caster"), image_reader, tensorflow::DT_UINT8);
  99.  
  100.     // The convention for image ops in TensorFlow is that all images are expected
  101.     // to be in batches, so that they're four-dimensional arrays with indices of
  102.     // [batch, height, width, channel]. Because we only have a single image, we
  103.     // have to add a batch dimension of 1 to the start with ExpandDims().
  104.     auto dims_expander = ExpandDims(root.WithOpName("dim"), uint8_caster, 0);
  105.  
  106.     // Bilinearly resize the image to fit the required dimensions.
  107.     // auto resized = ResizeBilinear(
  108.     //     root, dims_expander,
  109.     //     Const(root.WithOpName("size"), {input_height, input_width}));
  110.  
  111.  
  112.     // Subtract the mean and divide by the scale.
  113.     // auto div =  Div(root.WithOpName(output_name), Sub(root, dims_expander, {input_mean}),
  114.     //     {input_std});
  115.  
  116.  
  117.     //cast to int
  118.     //auto uint8_caster =  Cast(root.WithOpName("uint8_caster"), div, tensorflow::DT_UINT8);
  119.  
  120.     // This runs the GraphDef network definition that we've just constructed, and
  121.     // returns the results in the output tensor.
  122.     tensorflow::GraphDef graph;
  123.     TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
  124.  
  125.     std::unique_ptr<tensorflow::Session> session(
  126.                 tensorflow::NewSession(tensorflow::SessionOptions()));
  127.     TF_RETURN_IF_ERROR(session->Create(graph));
  128.     TF_RETURN_IF_ERROR(session->Run({inputs}, {"dim"}, {}, out_tensors));
  129.     return Status::OK();
  130. }
  131.  
  132. // Reads a model graph definition from disk, and creates a session object you
  133. // can use to run it.
  134. Status LoadGraph(const string& graph_file_name,
  135.                  std::unique_ptr<tensorflow::Session>* session) {
  136.     tensorflow::GraphDef graph_def;
  137.     Status load_graph_status =
  138.             ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
  139.     if(!load_graph_status.ok()) {
  140.         return tensorflow::errors::NotFound("Failed to load compute graph at '",
  141.                                             graph_file_name, "'");
  142.     }
  143.     session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
  144.     Status session_create_status = (*session)->Create(graph_def);
  145.     if (!session_create_status.ok()) {
  146.         return session_create_status;
  147.     }
  148.     return Status::OK();
  149. }
  150.  
  151. }
  152.  
  153. int tensorflow_example(int argc, char* argv[]) {
  154.     // These are the command-line flags the program can understand.
  155.     // They define where the graph and input data is located, and what kind of
  156.     // input the model expects. If you train your own model, or use something
  157.     // other than inception_v3, then you'll need to update these.
  158.     string const image(argv[1]);
  159.     std::string const model_folder("/home/ramsus/Qt/computer_vision_model/tensorflow/object_detection/");
  160.     string const graph = model_folder + "ssd_mobilenet_v1_coco_11_06_2017/frozen_inference_graph.pb";
  161.     int32 input_width = 300;
  162.     int32 input_height = 300;
  163.     float input_mean = 0;
  164.     float input_std = 255;
  165.     string input_layer = "image_tensor:0";
  166.     vector<string> output_layer ={ "detection_boxes:0", "detection_scores:0", "detection_classes:0", "num_detections:0" };
  167.  
  168.     // First we load and initialize the model.
  169.     std::unique_ptr<tensorflow::Session> session;
  170.     string graph_path = graph;
  171.     std::cout<<"graph path:"<<graph_path<<std::endl;
  172.  
  173.     LOG(ERROR) << "graph_path:" << graph_path;
  174.     Status load_graph_status = LoadGraph(graph_path, &session);
  175.     if (!load_graph_status.ok()) {
  176.         LOG(ERROR) << "LoadGraph ERROR!!!!"<< load_graph_status;
  177.         return -1;
  178.     }
  179.  
  180.     std::cout<<"load graph success"<<std::endl;
  181.     // Get the image from disk as a float array of numbers, resized and normalized
  182.     // to the specifications the main graph expects.
  183.     std::vector<Tensor> resized_tensors;
  184.     string const image_path = image;
  185.     Status read_tensor_status = ReadTensorFromImageFile(image_path, &resized_tensors);
  186.     if (!read_tensor_status.ok()) {
  187.         LOG(ERROR) << read_tensor_status;
  188.         return -1;
  189.     }
  190.     const Tensor& resized_tensor = resized_tensors[0];
  191.  
  192.     LOG(ERROR) <<"image shape:" << resized_tensor.shape().DebugString()<< ",len:" << resized_tensors.size() << ",tensor type:"<< resized_tensor.dtype();
  193.     // << ",data:" << resized_tensor.flat<tensorflow::uint8>();
  194.     // Actually run the image through the model.
  195.     std::vector<Tensor> outputs;
  196.     Status run_status = session->Run({{input_layer, resized_tensor}},
  197.                                      output_layer, {}, &outputs);
  198.     if (!run_status.ok()) {
  199.         LOG(ERROR) << "Running model failed: " << run_status;
  200.         return -1;
  201.     }
  202.  
  203.     int image_width = resized_tensor.dims();
  204.     int image_height = 0;
  205.     //int image_height = resized_tensor.shape()[1];
  206.  
  207.     LOG(ERROR) << "size:" << outputs.size() << ",image_width:" << image_width << ",image_height:" << image_height << endl;
  208.  
  209.     //tensorflow::TTypes<float>::Flat iNum = outputs[0].flat<float>();
  210.     tensorflow::TTypes<float>::Flat scores = outputs[1].flat<float>();
  211.  
  212.     tensorflow::TTypes<float>::Flat classes = outputs[2].flat<float>();
  213.     tensorflow::TTypes<float>::Flat num_detections = outputs[3].flat<float>();
  214.     auto boxes = outputs[0].flat_outer_dims<float,3>();
  215.  
  216.     LOG(ERROR) << "num_detections:" << num_detections(0) << "," << outputs[0].shape().DebugString();
  217.  
  218.     for(size_t i = 0; i < num_detections(0) && i < 20;++i){
  219.         if(scores(i) > 0.5){
  220.             LOG(ERROR) << i << ",score:" << scores(i) << ",class:" << classes(i)<< ",box:" << "," << boxes(0,i,0) << "," << boxes(0,i,1) << "," << boxes(0,i,2)<< "," << boxes(0,i,3);
  221.         }
  222.     }
  223. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement