Guest User

Untitled

a guest
Aug 10th, 2018
131
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.06 KB | None | 0 0
#include <wrl/client.h>
#include <winml.h>

#include <algorithm>
#include <cstdlib>
#include <exception>
#include <stdio.h>
#include <vector>

using Microsoft::WRL::ComPtr;

#include "cnpy.h"

  8. #define PRINTDBG
  9.  
  10. int main(int argc,char**argv)
  11. {
  12. //load model
  13.  
  14. ComPtr<IWinMLRuntime> spRuntime;
  15. HRESULT res= WinMLCreateRuntime(&spRuntime);
  16. printf("%d\n", res);
  17. ComPtr<IWinMLModel> spModel;
  18. res=spRuntime->LoadModel(L"C:\\src4\\winml\\mnist.onnx", &spModel);
  19. printf("%d\n", res);
  20.  
  21. //view graph metadata
  22.  
  23. WINML_MODEL_DESC *desc;
  24. res = spModel->GetDescription(&desc);
  25. printf("%d\n", res);
  26. int count = 0;
  27. LPCWSTR key,value;
  28. res=spModel->EnumerateMetadata(count, &key, &value);
  29. printf("%d\n", res);
  30. WINML_VARIABLE_DESC *desc3;
  31. res = spModel->EnumerateModelInputs(count, &desc3);
  32. printf("%d\n", res);
  33. printf("%S %d %d %d %d %d\n", desc3->Name,desc3->Tensor.NumDimensions, desc3->Tensor.pShape[0], desc3->Tensor.pShape[1], desc3->Tensor.pShape[2], desc3->Tensor.pShape[3]);
  34. WINML_VARIABLE_DESC *desc2;
  35. res = spModel->EnumerateModelOutputs(count, &desc2);
  36. printf("%S %d %d %d %d %d\n", desc2->Name,desc2->Tensor.NumDimensions, desc2->Tensor.pShape[0], desc2->Tensor.pShape[1], desc2->Tensor.pShape[2], desc2->Tensor.pShape[3]);
  37. printf("%d\n", res);
  38.  
  39. //bind resources
  40.  
  41. //setup device to use for inferencing
  42. ComPtr<IWinMLEvaluationContext> spContext;
  43. ComPtr<ID3D12Device> spDevice;
  44. res=spRuntime->CreateEvaluationContext(spDevice.Get(), &spContext);
  45. printf("%d\n", res);
  46.  
  47. //connect I/O data
  48. WINML_BINDING_DESC bindDescriptor;
  49. bindDescriptor.BindType = WINML_BINDING_TYPE::WINML_BINDING_TENSOR;
  50. bindDescriptor.Tensor.DataType=WINML_TENSOR_DATA_TYPE::WINML_TENSOR_FLOAT;
  51. bindDescriptor.Tensor.NumDimensions = 4;
  52. INT64 shape[4] = { 1,1,28,28 };
  53. bindDescriptor.Tensor.pShape = reinterpret_cast<INT64*>(&shape);
  54. //void *data=(void *)malloc(28 * 28 * sizeof(float));
  55. bindDescriptor.Tensor.DataSize = 28 * 28 * sizeof(float);
  56.  
  57. char datafile[100] = { 0 };
  58. int nt = 0;
  59. if (argc == 2)
  60. nt = atoi(argv[1]);
  61. sprintf_s(datafile,99,"C:\\src4\\winml\\test_data_%d.npz", nt);
  62. //"C:\\src4\\winml\\test_data_0.npz"
  63. cnpy::NpyArray arr2 = cnpy::npz_load(datafile, "inputs");
  64. int comp = 0;
  65. comp = arr2.word_size == sizeof(float);
  66. comp = arr2.shape.size() == 3;
  67. comp = arr2.shape[0] == 1;
  68. comp = arr2.shape[1] == 28;
  69. comp = arr2.shape[2] == 28;
  70. float* mv1 = arr2.data<float>();
  71. for (int i = 0; i < 28; i++)
  72. {
  73. for (int j = 0; j < 28; j++)
  74. {
  75. char c = ' ';
  76. if (mv1[i * 28 + j] > 0.10)
  77. c = '*';
  78. else
  79. c = ' ';
  80. printf("%c", c);
  81. //printf("%.1f ", mv1[i*28+j]);
  82. }
  83. printf("\n");
  84. }
  85.  
  86. bindDescriptor.Tensor.pData = /*data*/mv1;
  87. bindDescriptor.Name = desc3->Name;// LPCWSTR("hola");
  88. res=spContext->BindValue(&bindDescriptor);
  89. printf("%d\n", res);
  90.  
  91. WINML_BINDING_DESC bindDescriptor2;
  92. bindDescriptor2.BindType = WINML_BINDING_TYPE::WINML_BINDING_TENSOR;
  93. bindDescriptor2.Tensor.DataType = WINML_TENSOR_DATA_TYPE::WINML_TENSOR_FLOAT;
  94. bindDescriptor2.Tensor.NumDimensions = 2;
  95. INT64 shape2[2] = { 1,10 };
  96. bindDescriptor2.Tensor.pShape = reinterpret_cast<INT64*>(&shape2);
  97. float *data2 = (float *)calloc(1 * 10 , sizeof(float));
  98.  
  99. #ifdef PRINTDBG
  100. for (int i = 0; i < 10; i++)
  101. {
  102. printf("%f ", data2[i]);
  103. }
  104. printf("\n");
  105. #endif
  106. bindDescriptor2.Tensor.DataSize = 1 * 10 * sizeof(float);
  107. bindDescriptor2.Tensor.pData = data2;
  108. bindDescriptor2.Name = desc2->Name;// LPCWSTR("hola");
  109. res = spContext->BindValue(&bindDescriptor2);
  110. printf("%d\n", res);
  111.  
  112.  
  113.  
  114. cnpy::NpyArray arr3 = cnpy::npz_load(datafile, "outputs");
  115. comp = 0;
  116. comp = arr3.word_size == sizeof(float);
  117. comp = arr3.shape.size() == 2;
  118. comp = arr3.shape[0] == 1;
  119. comp = arr3.shape[1] == 10;
  120.  
  121. float* mv2 = arr3.data<float>();
  122. #ifdef PRINTDBG
  123. for (int i = 0; i < 10; i++)
  124. {
  125. printf("%.2f ", mv2[i]);
  126. }
  127. printf("\n");
  128. #endif
  129.  
  130. //Evaluate model (inference)
  131. res=spRuntime->EvaluateModel(spContext.Get());
  132.  
  133. //process results
  134. for (int i = 0; i < 10; i++)
  135. {
  136. printf("%.2f ", data2[i]);
  137. }
  138. printf("\n");
  139. int maxIndex = 0;
  140. float maxProbability = 0.0;
  141. for (int i = 0; i < 10 ; i++)
  142. {
  143. if (data2[i] > maxProbability)
  144. {
  145. maxIndex = i;
  146. maxProbability = data2[i];
  147. }
  148. }
  149. printf("numero detectado %d\n", maxIndex);
  150. }
Add Comment
Please, Sign In to add comment