Guest User

Untitled

a guest
Jan 17th, 2019
101
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 8.38 KB | None | 0 0
  #include <op_boilerplate.h>
  #include <pointercast.h>
  #include <NativeOps.h>
  #include <graph/Node.h>
  #include <graph/Variable.h>
  #include <graph/VariableSpace.h>
  #include <NDArray.h>
  #include <cnpy.h>
  #include <ops/ops.h>
  #include <helpers/shape.h>
  #include <ops/gemm.h>
  #include <GraphExecutioner.h>

  #include <cstdarg>
  #include <cstdlib>
  #include <iostream>
  #include <stdio.h>
  #include <string>
  #include <vector>

  #include <libpng/png.h>
  17.  
  18. using std::cout;
  19. using std::cerr;
  20. using std::endl;
  21. using std::string;
  22.  
  23. using namespace nd4j;
  24.  
  25.  
  /**
   * Prints a printf-style message followed by a newline to stderr, then
   * terminates the process with abort().
   *
   * Marked [[noreturn]] so that code following an `if (!ptr) abort_(...)`
   * guard is known (to the compiler and to readers) to run only when the
   * guard did not fire.
   *
   * @param s    printf-style format string
   * @param ...  arguments consumed by the format string
   */
  [[noreturn]] void abort_(const char * s, ...)
  {
      va_list args;
      va_start(args, s);
      vfprintf(stderr, s, args);
      va_end(args);
      fprintf(stderr, "\n");
      abort();
  }
  35.  
  36. void read_size(const char *file_name, int& width, int& height){
  37. png_structp png_ptr;
  38. png_infop info_ptr;
  39.  
  40. unsigned char header[8]; // 8 is the maximum size that can be checked
  41.  
  42. /* open file and test for it being a png */
  43. FILE *fp = fopen(file_name, "rb");
  44. if (!fp)
  45. abort_("[read_png_file] File %s could not be opened for reading", file_name);
  46. fread(header, 1, 8, fp);
  47. if (png_sig_cmp(header, 0, 8))
  48. abort_("[read_png_file] File %s is not recognized as a PNG file", file_name);
  49.  
  50.  
  51. /* initialize stuff */
  52. png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL);
  53.  
  54. if (!png_ptr)
  55. abort_("[read_png_file] png_create_read_struct failed");
  56.  
  57. info_ptr = png_create_info_struct(png_ptr);
  58. if (!info_ptr)
  59. abort_("[read_png_file] png_create_info_struct failed");
  60.  
  61. if (setjmp(png_jmpbuf(png_ptr)))
  62. abort_("[read_png_file] Error during init_io");
  63.  
  64. png_init_io(png_ptr, fp);
  65. png_set_sig_bytes(png_ptr, 8);
  66.  
  67. png_read_info(png_ptr, info_ptr);
  68.  
  69. width = png_get_image_width(png_ptr, info_ptr);
  70. height = png_get_image_height(png_ptr, info_ptr);
  71. }
  72.  
  73. void read_png_file(const char *file_name, NDArray* output) {
  74.  
  75. int x, y;
  76.  
  77. int width, height;
  78. png_byte color_type;
  79. png_byte bit_depth;
  80.  
  81. png_structp png_ptr;
  82. png_infop info_ptr;
  83. int number_of_passes;
  84. png_bytep * row_pointers;
  85.  
  86. unsigned char header[8]; // 8 is the maximum size that can be checked
  87.  
  88. /* open file and test for it being a png */
  89. FILE *fp = fopen(file_name, "rb");
  90. if (!fp)
  91. abort_("[read_png_file] File %s could not be opened for reading", file_name);
  92. fread(header, 1, 8, fp);
  93. if (png_sig_cmp(header, 0, 8))
  94. abort_("[read_png_file] File %s is not recognized as a PNG file", file_name);
  95.  
  96.  
  97. /* initialize stuff */
  98. png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL);
  99.  
  100. if (!png_ptr)
  101. abort_("[read_png_file] png_create_read_struct failed");
  102.  
  103. info_ptr = png_create_info_struct(png_ptr);
  104. if (!info_ptr)
  105. abort_("[read_png_file] png_create_info_struct failed");
  106.  
  107. if (setjmp(png_jmpbuf(png_ptr)))
  108. abort_("[read_png_file] Error during init_io");
  109.  
  110. png_init_io(png_ptr, fp);
  111. png_set_sig_bytes(png_ptr, 8);
  112.  
  113. png_read_info(png_ptr, info_ptr);
  114.  
  115. width = png_get_image_width(png_ptr, info_ptr);
  116. height = png_get_image_height(png_ptr, info_ptr);
  117. color_type = png_get_color_type(png_ptr, info_ptr);
  118. bit_depth = png_get_bit_depth(png_ptr, info_ptr);
  119. printf("bpp: %d\n", bit_depth);
  120.  
  121. number_of_passes = png_set_interlace_handling(png_ptr);
  122. png_read_update_info(png_ptr, info_ptr);
  123.  
  124.  
  125. /* read file */
  126. if (setjmp(png_jmpbuf(png_ptr)))
  127. abort_("[read_png_file] Error during read_image");
  128.  
  129. row_pointers = (png_bytep*) malloc(sizeof(png_bytep) * height);
  130. for (y=0; y<height; y++)
  131. row_pointers[y] = (png_byte*) malloc(png_get_rowbytes(png_ptr,info_ptr));
  132.  
  133. png_read_image(png_ptr, row_pointers);
  134. printf("rowbytes: %d \n", png_get_rowbytes(png_ptr,info_ptr) );
  135. for(int y = 0; y < height; y++) {
  136. for(int x = 0; x < width; x++) {
  137. unsigned char r,g,b;
  138. r = ((unsigned char *)(row_pointers[y]))[x*3];
  139. g = ((unsigned char *)(row_pointers[y]))[x*3+1];
  140. b = ((unsigned char *)(row_pointers[y]))[x*3+2];
  141. output->p<float>(0,y,x,0, (((float)(r))/128.0f) + 1.0f);
  142. output->p<float>(0,y,x,1, (((float)(g))/128.0f) + 1.0f);
  143. output->p<float>(0,y,x,2, (((float)(b))/128.0f) + 1.0f);
  144. /*output->p<float>(0,y,x,0, (float)(r)/255.0f);
  145. output->p<float>(0,y,x,1, (float)(g)/255.0f);
  146. output->p<float>(0,y,x,2, (float)(b)/255.0f);*/
  147. if(y==50){
  148. //printf("x=%d, y=%d, pixel = %d %d %d %f\n", x,y,r,g,b, 1.0f-(float)(r)/128.0f);
  149. }
  150. }
  151. }
  152. fclose(fp);
  153. }
  154.  
  155.  
  156. void read_csv(string filename, NDArray* output, int width){
  157. char * line;
  158. FILE *fp;
  159. size_t size=256;
  160. fp = fopen(filename.c_str(),"r");
  161. if(!fp) {
  162. printf("Could not open file %s\n", filename.c_str());
  163. exit(-1);
  164. }
  165. int ind=0;
  166. line = (char*)(malloc(256));
  167. ssize_t read=1;
  168. read=getline(&line, &size, fp);
  169. while(read>0){
  170. int x = (ind/3) % width;
  171. int y = ind/(3*width);
  172. output->p<float>(0,y,x, ind%3, atof(line));
  173. read=getline(&line, &size, fp);
  174. ind++;
  175. }
  176. free(line);
  177. return;
  178. }
  179.  
  180. void showVariable(Graph* graph){
  181. std::vector<Variable*> variablesVector = graph->getVariableSpace()->getVariables();
  182. for(int i=0;i<variablesVector.size();i++){
  183. Variable* var = variablesVector[i];
  184. //varNames[*(var->getName())] = var;
  185. printf("_%s_ \n", var->getName()->c_str());
  186. }
  187.  
  188. }
  189.  
  190. int main(int argc, char** argv){
  191. /*if(argc<2){
  192. cerr << "Missing argument." << endl << "Usage: ./GraphExecutor model.fb [image.png] [input_layer_name] [output_layer_name]" <<endl;
  193. return -1;
  194. }*/
  195.  
  196. string inFilename = "/home/yves/dl4j/models/flatBufferModels/master_version/mobilenet_v1_0.5_128_frozen.fb";
  197. string imageFile = "/home/yves/dl4j/datasets/carsAndCats_128/cat1.png";
  198. string inputLayerName = "input";
  199. string outputLayerName = "MobilenetV1/Predictions/Reshape_1";
  200. if(argc>1) inFilename.assign(argv[1]);
  201. if(argc>2) imageFile.assign(argv[2]);
  202. if(argc>3) inputLayerName.assign(argv[3]);
  203. if(argc>4) outputLayerName.assign(argv[4]);
  204.  
  205.  
  206. /*nd4j::Environment::getInstance()->setElementwiseThreshold(100000000);
  207. nd4j::Environment::getInstance()->setTadThreshold(100000000);
  208. nd4j::Environment::getInstance()->setDebug(false);
  209. nd4j::Environment::getInstance()->setVerbose(false);*/
  210.  
  211. printf("Import FlatBuffer\n");
  212. auto graph = GraphExecutioner::importFromFlatBuffers(inFilename.c_str());
  213. printf("Build Graph\n");
  214. graph->buildGraph();
  215.  
  216. showVariable(graph);
  217. //graph->printOut();
  218.  
  219. Variable* input = graph->getVariableSpace()->getVariable(&inputLayerName);
  220. std::vector<Nd4jLong> shape;
  221. int width, height;
  222. read_size(imageFile.c_str(), width, height);
  223.  
  224. shape.push_back(1);
  225. shape.push_back(height);
  226. shape.push_back(width);
  227. shape.push_back(3);
  228.  
  229. NDArray* inputArray = new NDArray('f', shape, nd4j::DataType::FLOAT32);
  230. inputArray->assign(0.0f);
  231. //read_png_file(imageFile.c_str(), &inputArray);
  232. //read_csv("/home/yves/dl4j/datasets/mobilev1.csv", &inputArray, width);
  233. //read_png_file(imageFile.c_str(), graph->getVariableSpace()->getVariable(&inputLayerName)->getNDArray());
  234.  
  235. input->setNDArray(inputArray);
  236. printf("Before:%s \n", input->getName()->c_str());
  237. Variable *v = input->clone();
  238. printf("After:%s \n", v->getName()->c_str());
  239.  
  240. //graph->getVariableSpace()->replaceVariable(new Variable(input));
  241. graph->getVariableSpace()->replaceVariable(input->clone());
  242.  
  243. printf("Executing graph\n");
  244. Nd4jStatus status = GraphExecutioner::execute(graph);
  245. printf("Execution finished\n");
  246. std::vector<Variable*> results = *graph->fetchOutputs();
  247.  
  248.  
  249. auto resultVar = graph->getVariableSpace()->getVariable(&outputLayerName);
  250. printf("Empty? %d\n", resultVar->isEmpty());
  251.  
  252. NDArray* result = graph->getVariableSpace()->getVariable(&outputLayerName)->getNDArray();
  253.  
  254. //resultVar = graph->fetchOutputs()->at(0);
  255. //result = resultVar->getNDArray();
  256. printf("%s: \n", resultVar->getName()->c_str());
  257. std::vector<float> rvec = result->getBufferAsVector<float>();
  258. printf("argMax %d: %f \n", result->argMax(), rvec[result->argMax()]);
  259. /*for( int i=0;i<rvec.size();i++){
  260. if(rvec[i]>0.001)
  261. printf("%d: %f \n", i, rvec[i]);
  262. }
  263. printf("\n");*/
  264. nd4j_printf("Execution status=%d\n",status);
  265. }
Add Comment
Please, Sign In to add comment