Advertisement
Guest User

Untitled

a guest
Dec 3rd, 2019
139
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
C++ 9.49 KB | None | 0 0
  1. #include <iostream>
  2. #include <fstream>
  3. #include <sstream>
  4.  
  5. #include <assert.h>
  6. #include <sys/time.h>
  7. #include <unistd.h>
  8.  
  9. #include <ros/ros.h>
  10. #include <image_transport/image_transport.h>
  11. #include <cv_bridge/cv_bridge.h>
  12.  
  13. #include "dnn_detect/DetectedObject.h"
  14. #include "dnn_detect/DetectedObjectArray.h"
  15. #include "dnn_detect/Detect.h"
  16.  
  17. #include <opencv2/highgui.hpp>
  18. #include <opencv2/dnn.hpp>
  19. #include <opencv2/calib3d.hpp>
  20.  
  21. #include <list>
  22. #include <string>
  23. #include <boost/algorithm/string.hpp>
  24. #include <boost/format.hpp>
  25.  
  26. #include <thread>
  27. #include <mutex>
  28. #include <condition_variable>
  29.  
  30. #include "objectDetection.hpp"
  31.  
  32. using namespace std;
  33. using namespace cv;
  34.  
  35. string yoloBasePath = "/home/phil/Development/RosDev/opencv_ws/src/opencv_processing/dat/"; // relative paths not working
  36. string yoloClassesFile = yoloBasePath + "coco.names";
  37. string yoloModelConfiguration = yoloBasePath + "yolov3-tiny.cfg";
  38. string yoloModelWeights = yoloBasePath + "yolov3-tiny.weights";
  39. float confThreshold = 0.01f;
  40. float nmsThreshold = 0.9f;
  41.  
  42. std::condition_variable cond;
  43. std::mutex mutx;
  44.  
  45. class DnnNode {
  46.   private:
  47.     ros::Publisher results_pub;
  48.  
  49.     image_transport::ImageTransport it;
  50.     image_transport::Subscriber img_sub;
  51.  
  52.     // if set, we publish the images that contain objects
  53.     bool publish_images;
  54.  
  55.     int frame_num;
  56.     float min_confidence;
  57.     int im_size;
  58.     int rotate_flag;
  59.     float scale_factor;
  60.     float mean_val;
  61.     std::vector<std::string> class_names;
  62.  
  63.     image_transport::Publisher image_pub;
  64.  
  65.     cv::dnn::Net net;
  66.     cv::Mat resized_image;
  67.     cv::Mat rotated_image;
  68.  
  69.     bool single_shot;
  70.     volatile bool triggered;
  71.     volatile bool processed;
  72.  
  73.     dnn_detect::DetectedObjectArray results;
  74.  
  75.     ros::ServiceServer detect_srv;
  76.  
  77.     bool trigger_callback(dnn_detect::Detect::Request &req,
  78.                           dnn_detect::Detect::Response &res);
  79.  
  80.     void image_callback(const sensor_msgs::ImageConstPtr &msg);
  81.  
  82.   public:
  83.     DnnNode(ros::NodeHandle &nh);
  84. };
  85.  
  86. bool DnnNode::  trigger_callback(dnn_detect::Detect::Request &req,
  87.                                  dnn_detect::Detect::Response &res)
  88. {
  89.     ROS_INFO("Got service request");
  90.     triggered = true;
  91.  
  92.     std::unique_lock<std::mutex> lock(mutx);
  93.  
  94.     while (!processed) {
  95.       cond.wait(lock);
  96.     }
  97.     res.result = results;
  98.     processed = false;
  99.     return true;
  100. }
  101.  
  102.  
  103. void DnnNode::image_callback(const sensor_msgs::ImageConstPtr & msg)
  104. {
  105.     if (single_shot && !triggered) {
  106.         return;
  107.     }
  108.     triggered = false;
  109.  
  110.     ROS_INFO("Got image %d", msg->header.seq);
  111.     frame_num++;
  112.  
  113.     cv_bridge::CvImagePtr cv_ptr;
  114.  
  115.     try {
  116.         cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::BGR8);
  117.  
  118.         int w = cv_ptr->image.cols;
  119.         int h = cv_ptr->image.rows;
  120.  
  121.         if (rotate_flag >= 0) {
  122.           cv::rotate(cv_ptr->image, rotated_image, rotate_flag);
  123.           rotated_image.copyTo(cv_ptr->image);
  124.         }
  125.  
  126.         // load class names from file
  127.         vector<string> classes;
  128.         ifstream ifs(yoloClassesFile.c_str());
  129.         string line;
  130.         while (getline(ifs, line)) classes.push_back(line);
  131.  
  132.         cv::resize(cv_ptr->image, resized_image, cvSize(im_size, im_size));
  133.         //cv::Mat img = cv_ptr->image;
  134.         cv::Mat blob;
  135.         vector<cv::Mat> netOutput;
  136.         double scalefactor = 1/255.0;
  137.         cv::Size size = cv::Size(320, 320); // 416/416, 320/320, 608/608
  138.         cv::Scalar mean = cv::Scalar(0,0,0);
  139.         bool swapRB = false;
  140.         bool crop = false;
  141.         cv::dnn::blobFromImage(resized_image, blob, scalefactor, size, mean, swapRB, crop);
  142.  
  143.         // Get names of output layers
  144.         vector<cv::String> names;
  145.         vector<int> outLayers = net.getUnconnectedOutLayers(); // get  indices of  output layers, i.e.  layers with unconnected outputs
  146.         vector<cv::String> layersNames = net.getLayerNames(); // get  names of all layers in the network
  147.  
  148.         names.resize(outLayers.size());
  149.         for (size_t i = 0; i < outLayers.size(); ++i) // Get the names of the output layers in names
  150.             names[i] = layersNames[outLayers[i] - 1];
  151.  
  152.         net.setInput(blob);
  153.         net.forward(netOutput, names);
  154.  
  155.         // Scan through all bounding boxes and keep only the ones with high confidence
  156.         vector<int> classIds; vector<float> confidences; vector<cv::Rect> boxes;
  157.         for (size_t i = 0; i < netOutput.size(); ++i)
  158.         {
  159.             float* data = (float*)netOutput[i].data;
  160.             for (int j = 0; j < netOutput[i].rows; ++j, data += netOutput[i].cols)
  161.             {
  162.                 cv::Mat scores = netOutput[i].row(j).colRange(5, netOutput[i].cols);
  163.                 cv::Point classId;
  164.                 double confidence;
  165.  
  166.                 // Get the value and location of the maximum score
  167.                 cv::minMaxLoc(scores, 0, &confidence, 0, &classId);
  168.                 if (confidence > confThreshold)
  169.                 {
  170.                     cv::Rect box; int cx, cy;
  171.                     cx = (int)(data[0] * img.cols);
  172.                     cy = (int)(data[1] * img.rows);
  173.                     box.width = (int)(data[2] * img.cols);
  174.                     box.height = (int)(data[3] * img.rows);
  175.                     box.x = cx - box.width/2; // left
  176.                     box.y = cy - box.height/2; // top
  177.  
  178.                     boxes.push_back(box);
  179.                     classIds.push_back(classId.x);
  180.                     confidences.push_back((float)confidence);
  181.                 }
  182.             }
  183.         }
  184.  
  185.  
  186.         std::unique_lock<std::mutex> lock(mutx);
  187.         results.header.frame_id = msg->header.frame_id;
  188.         results.objects.clear();
  189.  
  190.         // perform non-maxima suppression
  191.         vector<int> indices;
  192.         cv::dnn::NMSBoxes(boxes, confidences, confThreshold, nmsThreshold, indices);
  193.  
  194.         for(auto it=indices.begin(); it!=indices.end(); ++it) {
  195.  
  196.             std::string label = str(boost::format{"%1% %2%"} % classIds[*it] % confidences[*it]);
  197.             ROS_INFO("%s", label.c_str());
  198.  
  199.             dnn_detect::DetectedObject obj;
  200.             int x_min = boxes[*it].x;
  201.             int x_max = boxes[*it].x + boxes[*it].width;
  202.             int y_min = boxes[*it].y;
  203.             int y_max = boxes[*it].y + boxes[*it].height;
  204.             obj.x_min = x_min;
  205.             obj.x_max = x_max;
  206.             obj.y_min = y_min;
  207.             obj.y_max = y_max;
  208.             obj.class_name = classIds[*it];
  209.             obj.confidence = confidences[*it];
  210.             results.objects.push_back(obj);
  211.  
  212.             Rect object(x_min, y_min, x_max-x_min, y_max-y_min);
  213.  
  214.             rectangle(cv_ptr->image, object, Scalar(0, 255, 0));
  215.             int baseline=0;
  216.             cv::Size text_size = cv::getTextSize(label,
  217.                                  FONT_HERSHEY_SIMPLEX, 0.75, 2, &baseline);
  218.             putText(cv_ptr->image, label, Point(x_min, y_min-text_size.height),
  219.                     FONT_HERSHEY_SIMPLEX, 0.75, Scalar(0, 255, 0));
  220.         }
  221.  
  222.  
  223.     results_pub.publish(results);
  224.  
  225.     image_pub.publish(cv_ptr->toImageMsg());
  226.  
  227.     }
  228.     catch(cv_bridge::Exception & e) {
  229.         ROS_ERROR("cv_bridge exception: %s", e.what());
  230.     }
  231.     catch(cv::Exception & e) {
  232.         ROS_ERROR("cv exception: %s", e.what());
  233.     }
  234.     ROS_DEBUG("Notifying condition variable");
  235.     processed = true;
  236.     cond.notify_all();
  237. }
  238.  
  239. DnnNode::DnnNode(ros::NodeHandle & nh) : it(nh)
  240. {
  241.     frame_num = 0;
  242.  
  243.     std::string dir;
  244.     std::string proto_net_file;
  245.     std::string caffe_model_file;
  246.     std::string classes("background,"
  247.        "aeroplane,bicycle,bird,boat,bottle,bus,car,cat,chair,"
  248.        "cow,diningtable,dog,horse,motorbike,person,pottedplant,"
  249.        "sheep,sofa,train,tvmonitor");
  250.  
  251.  
  252.  
  253.     nh.param<bool>("single_shot", single_shot, false);
  254.     nh.param<bool>("publish_images", publish_images, false);
  255.     nh.param<string>("data_dir", dir, "");
  256.     nh.param<string>("protonet_file", proto_net_file, "MobileNetSSD_deploy.prototxt.txt");
  257.     nh.param<string>("caffe_model_file", caffe_model_file, "MobileNetSSD_deploy.caffemodel");
  258.     nh.param<float>("min_confidence", min_confidence, 0.01);
  259.     nh.param<int>("im_size", im_size, 320);
  260.     nh.param<int>("rotate_flag", rotate_flag, -1);
  261.     nh.param<float>("scale_factor", scale_factor, 1/255.0);
  262.     nh.param<float>("mean_val", mean_val, 127.5f);
  263.     nh.param<std::string>("class_names", classes, classes);
  264.  
  265.     boost::split(class_names, classes, boost::is_any_of(","));
  266.     ROS_INFO("Read %d class names", (int)class_names.size());
  267.  
  268.     try {
  269.         net = cv::dnn::readNetFromDarknet(yoloModelConfiguration, yoloModelWeights);
  270.         net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
  271.         net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA_FP16);
  272.     }
  273.     catch(cv::Exception & e) {
  274.         ROS_ERROR("cv exception: %s", e.what());
  275.         exit(1);
  276.     }
  277.  
  278.     triggered = false;
  279.  
  280.     detect_srv = nh.advertiseService("detect", &DnnNode::trigger_callback, this);
  281.  
  282.     results_pub =
  283.         nh.advertise<dnn_detect::DetectedObjectArray>("/dnn_objects", 20);
  284.  
  285.     image_pub = it.advertise("/dnn_images", 1);
  286.  
  287.     img_sub = it.subscribe("/camera", 1,
  288.                            &DnnNode::image_callback, this);
  289.  
  290.     ROS_INFO("DNN detection ready");
  291. }
  292.  
  293. int main(int argc, char ** argv) {
  294.     ros::init(argc, argv, "dnn_detect");
  295.     ros::NodeHandle nh("~");
  296.  
  297.     DnnNode node = DnnNode(nh);
  298.     ros::MultiThreadedSpinner spinner(2);
  299.     spinner.spin();
  300.  
  301.     return 0;
  302. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement