Advertisement
Guest User

Untitled

a guest
Jun 19th, 2018
198
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 6.75 KB | None | 0 0
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include <sys/stat.h>
#include <time.h>

#include <cuda_runtime_api.h>
#include <cudnn.h>
#include <cublas_v2.h>
#include <opencv2/opencv.hpp>

#include "NvInfer.h"
#include "NvCaffeParser.h"


using namespace nvinfer1;
using namespace nvcaffeparser1;
using namespace std;
  24.  
  25. #define CHECK(status)                                   \
  26. {                                                       \
  27.     if (status != 0)                                    \
  28.     {                                                   \
  29.         std::cout << "Cuda failure: " << status;        \
  30.         abort();                                        \
  31.     }                                                   \
  32. }
  33.  
  34. // stuff we know about the network and the caffe input/output blobs
  35. static const int INPUT_C = 3;
  36. static const int INPUT_H = 160;
  37. static const int INPUT_W = 320;
  38.  
  39.  
  40. const char* INPUT_BLOB_NAME = "data";
  41. const char* OUTPUT_BLOB_NAME = "seg-prob-region";
  42.  
  43. // Logger for GIE info/warning/errors
  44. class Logger : public ILogger
  45. {
  46.     void log(Severity severity, const char* msg) override
  47.     {
  48.         // suppress info-level messages
  49.         if (severity != Severity::kINFO || false)
  50.             std::cout << msg << std::endl;
  51.     }
  52. } gLogger;
  53.  
  54.  
  55. void caffeToGIEModel(const std::string& deployFile,                 // name for caffe prototxt
  56.                      const std::string& modelFile,                  // name for model
  57.                      const std::vector<std::string>& outputs,       // network outputs
  58.                      unsigned int maxBatchSize,                     // batch size - NB must be at least as large as the batch we want to run with)
  59.                      nvcaffeparser1::IPluginFactory* pluginFactory, // factory for plugin layers
  60.                      IHostMemory *&gieModelStream)                  // output stream for the GIE model
  61. {
  62.     // create the builder
  63.     IBuilder* builder = createInferBuilder(gLogger);
  64.  
  65.     // parse the caffe model to populate the network, then set the outputs
  66.     INetworkDefinition* network = builder->createNetwork();
  67.     ICaffeParser* parser = createCaffeParser();
  68.     parser->setPluginFactory(pluginFactory);
  69.  
  70.  
  71.     const IBlobNameToTensor* blobNameToTensor = parser->parse(deployFile.c_str(),
  72.                                                               modelFile.c_str(),
  73.                                                               *network,
  74.                                                               DataType::kFLOAT);
  75.  
  76.     // specify which tensors are outputs
  77.     for (auto& s : outputs)
  78.     {
  79.         if (blobNameToTensor->find(s.c_str()) == nullptr)
  80.         {
  81.             std::cout << "could not find output blob " << s << std::endl;
  82.             return ;
  83.         }
  84.         network->markOutput(*blobNameToTensor->find(s.c_str()));
  85.     }
  86.  
  87.     // Build the engine
  88.     builder->setMaxBatchSize(maxBatchSize);
  89.     builder->setMaxWorkspaceSize( 1000 * (1 << 20));
  90.     //builder->setHalf2Mode(fp16);
  91.  
  92.     ICudaEngine* engine = builder->buildCudaEngine(*network);
  93.     assert(engine);
  94.  
  95.     // we don't need the network any more, and we can destroy the parser
  96.     network->destroy();
  97.     parser->destroy();
  98.  
  99.     // serialize the engine, then close everything down
  100.     gieModelStream = engine->serialize();
  101.  
  102.     engine->destroy();
  103.     builder->destroy();
  104.     shutdownProtobufLibrary();
  105.  
  106. }
  107.  
  108.  
  109. class ELU_Plugin: public IPlugin
  110. {
  111. public:
  112.  
  113.  
  114.     ELU_Plugin ()
  115.     {
  116.  
  117.     }
  118.  
  119.     ELU_Plugin(const Weights *weights, int nbWeights, int nbOutputChannels): mNbOutputChannels(nbOutputChannels)
  120.     {
  121.  
  122.     }
  123.     // create the plugin at runtime from a byte stream
  124.     ELU_Plugin(const void* data, size_t length)
  125.     {
  126.  
  127.     }
  128.  
  129.     ~ELU_Plugin()
  130.     {
  131.  
  132.     }
  133.  
  134.     int getNbOutputs() const override
  135.     {
  136.         return 1;
  137.     }
  138.  
  139.     Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override
  140.     {
  141.         cout << "ELU Shape " << inputs[0].d[0] << ", " << inputs[0].d[1] << ", " << inputs[0].d[2] <<endl;
  142.         return DimsCHW(inputs[0].d[0], inputs[0].d[1], inputs[0].d[2]);
  143.     }
  144.  
  145.     void configure(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, int maxBatchSize) override
  146.     {
  147.  
  148.     }
  149.  
  150.     int initialize() override
  151.     {
  152.         CHECK(cudnnCreate(&mCudnn));                            // initialize cudnn and cublas
  153.         CHECK(cublasCreate(&mCublas));
  154.         CHECK(cudnnCreateTensorDescriptor(&mSrcDescriptor));    // create cudnn tensor descriptors we need for bias addition
  155.         CHECK(cudnnCreateTensorDescriptor(&mDstDescriptor));
  156.         CHECK(cudnnCreateActivationDescriptor(&mActDescriptor));
  157.         return 0;
  158.     }
  159.  
  160.     void terminate() override
  161.     {
  162.         CHECK(cublasDestroy(mCublas));
  163.         CHECK(cudnnDestroy(mCudnn));
  164.     }
  165.  
  166.     size_t getWorkspaceSize(int maxBatchSize) const override
  167.     {
  168.         return 0;
  169.     }
  170.  
  171.  
  172.     int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override
  173.     {
  174.  
  175.         printf("Enqueue start \n");
  176.         const float kONE = 1.0f, kZERO = 0.0f;
  177.         cublasSetStream(mCublas, stream);
  178.         cudnnSetStream(mCudnn, stream);
  179.         CHECK(cudnnSetActivationDescriptor(mActDescriptor, CUDNN_ACTIVATION_ELU, CUDNN_PROPAGATE_NAN, kONE));
  180.         CHECK(cudnnActivationForward(mCudnn, mActDescriptor, &kONE, mSrcDescriptor, inputs, &kZERO, mDstDescriptor, outputs));
  181.  
  182.         return 0;
  183.     }
  184.  
  185.     size_t getSerializationSize() override
  186.     {
  187.         return 0;
  188.     }
  189.  
  190.     void serialize(void* buffer) override
  191.     {
  192.  
  193.     }
  194. private:
  195.     int mNbOutputChannels, mNbInputChannels;
  196.     cudnnHandle_t mCudnn;
  197.     cublasHandle_t mCublas;
  198.     cudnnTensorDescriptor_t mSrcDescriptor, mDstDescriptor;
  199.     cudnnActivationDescriptor_t mActDescriptor;
  200. };
  201.  
  202. // integration for serialization
  203. class PluginFactory : public nvinfer1::IPluginFactory, public nvcaffeparser1::IPluginFactory
  204. {
  205. public:
  206.  
  207.     // caffe parser plugin implementation
  208.     bool isPlugin(const char* name) override
  209.     {
  210.         string s (name);
  211.  
  212.         cout << "Layer name: " << s << endl;
  213.  
  214.         if (s.substr(0, 3) == "elu")
  215.             return 1;
  216.         else
  217.             return 0;
  218.     }
  219.  
  220.     virtual nvinfer1::IPlugin* createPlugin(const char* layerName, const nvinfer1::Weights* weights, int nbWeights) override
  221.     {
  222.         // there's no way to pass parameters through from the model definition, so we have to define it here explicitly
  223.         mPlugin = std::unique_ptr<ELU_Plugin>(new ELU_Plugin());
  224.         return mPlugin.get();
  225.     }
  226.  
  227.     // deserialization plugin implementation
  228.     IPlugin* createPlugin(const char* layerName, const void* serialData, size_t serialLength) override
  229.     {
  230.         mPlugin = std::unique_ptr<ELU_Plugin>(new ELU_Plugin(serialData, serialLength));
  231.         return mPlugin.get();
  232.     }
  233.  
  234.     // the application has to destroy the plugin when it knows it's safe to do so
  235.     void destroyPlugin()
  236.     {
  237.         mPlugin.release();
  238.     }
  239.  
  240.     std::unique_ptr<ELU_Plugin> mPlugin{ nullptr };
  241. };
  242.  
  243. int main(int argc, char** argv)
  244. {
  245.     // create a GIE model from the caffe model and serialize it to a stream
  246.     PluginFactory pluginFactory;
  247.     IHostMemory *gieModelStream{ nullptr };
  248.     caffeToGIEModel("/home/models/u-net0/deploy_short.prototxt", "/home/models/u-net0/model.caffemodel", std::vector < std::string > { OUTPUT_BLOB_NAME }, 1, &pluginFactory, gieModelStream);
  249.     pluginFactory.destroyPlugin();
  250.  
  251.  
  252.     return 0;
  253. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement