TylerHumanCompiler

envelopecnn.cpp

Jul 9th, 2020
77
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 1.75 KB | None | 0 0
  1. #include "envelopecnn.h"
  2.  
  3. using namespace std;
  4. using namespace tensorflow;
  5.  
  6.  
  7. EnvelopeCNN::EnvelopeCNN(string modelpath, string graphpath) {
  8.     SessionOptions options = SessionOptions();
  9.     ConfigProto* config = &options.config;
  10.     (*config->mutable_device_count())["GPU"] = 0;
  11.     config->mutable_gpu_options()->set_visible_device_list("");
  12.     config->set_allow_soft_placement(true);
  13.  
  14.     status = NewSession(options, &session);
  15.     if (status != tensorflow::Status::OK()) {cout << status << endl;}
  16.     data = Tensor(DT_FLOAT, TensorShape({1, 7, 96, 2, 2}));
  17.     status = LoadSavedModel(options, RunOptions(), modelpath, {kSavedModelTagServe}, &model);
  18.     if (status != tensorflow::Status::OK()) {cout << status << endl;}
  19.     status = ReadBinaryProto(Env::Default(), graphpath, &graph_def);
  20.     if (status != tensorflow::Status::OK()) {cout << status << endl;}
  21.     status = session->Create(graph_def);
  22.     if (status != tensorflow::Status::OK()) {cout << status << endl;}  
  23. }
  24.  
  25. void EnvelopeCNN::predict(float**** inp) {
  26.     Tensor data = Tensor(DT_FLOAT, TensorShape({1, 7, 96, 2, 2}));
  27.     float* data_ = dat.flat<float>().data();
  28.     unsigned int count = 0;
  29.     for(unsigned int a = 0; a < 7; a++) {
  30.         for(unsigned int b = 0; b < 96; b++) {
  31.             data_[count] = inp[a][b][0][0];
  32.             data_[count+1] = inp[a][b][0][1];
  33.             data_[count+2] = inp[a][b][1][0];
  34.             data_[count+3] = inp[a][b][0][1];
  35.             count+=4;
  36.         }
  37.     }
  38.  
  39.     tensor_dict feed_dict = {
  40.             {"input", data},
  41.     };
  42.  
  43.     std::vector<tensorflow::Tensor> outputs;
  44.     status = session->Run(feed_dict, {"output0/Reshape", "output1/Reshape", "output2/Reshape", "output3/Reshape"}, {}, &outputs);
  45.     if (status != tensorflow::Status::OK()) {cout << "SHIT HAPPENED:\n" << status << endl;}
  46. }
Add Comment
Please, Sign In to add comment