Advertisement
Guest User

Untitled

a guest
Apr 4th, 2020
249
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.63 KB | None | 0 0
  #include <array>
  #include <cstddef>
  #include <cstdint>
  #include <fstream>
  #include <iostream>
  #include <string>
  #include <vector>
  7.  
  /// One 28x28 grayscale digit image together with its label.
  /// Pixels are stored row-major: pixel (row, col) is data[row * kCols + col].
  struct Image {
      static constexpr int kRows = 28;  ///< image height in pixels
      static constexpr int kCols = 28;  ///< image width in pixels

      std::array<uint8_t, kRows * kCols> data;  ///< row-major pixel values (0-255)
      uint8_t label;                            ///< class label of the image

      /// Builds an image from its label and a row-major pixel buffer.
      // The initializer list follows declaration order (data, then label) so the
      // written order matches the order in which members are really initialized.
      Image(uint8_t label, const std::array<uint8_t, kRows * kCols>& data) : data(data), label(label) {}

      /// Returns the pixel at (row, col). Both indices must lie in [0, 28);
      /// no bounds check is performed.
      uint8_t at(int row, int col) const {
          return data[col + row * kCols];
      }
  };
  17. std::ostream& operator<<(std::ostream& out, const Image& image) {
  18.     out << "label = " << static_cast<int>(image.label) << '\n';
  19.     for (int row = 0; row < 28; ++row) {
  20.         for (int col = 0; col < 28; ++col) {
  21.             out << ((image.at(row, col) >= 128) ? '#' : ' ');
  22.         }
  23.         out << '\n';
  24.     }
  25.     return out;
  26. }
  27.  
  /// Reads `Size` raw bytes from `in` into a fixed-size array.
  ///
  /// The buffer is value-initialized, so bytes that could not be read (stream
  /// ended early or was already failed) are 0 rather than indeterminate.
  /// A short read leaves `in` in a failed state (std::istream::read sets
  /// eofbit/failbit), which callers can test afterwards.
  template<std::size_t Size> std::array<uint8_t, Size> readBytes(std::istream& in) {
      std::array<uint8_t, Size> buf{};
      in.read(reinterpret_cast<char*>(buf.data()), static_cast<std::streamsize>(Size));
      return buf;
  }
  33. uint32_t read32BigEndianInt(std::istream& in) {
  34.     auto buf = readBytes<4>(in);
  35.     return (buf[0] << 24) | (buf[1] << 16) | (buf[2] << 8) | buf[3];
  36. }
  37.  
  38. std::vector<Image> readDataSet(const std::string& labelFileName, const std::string& imageFileName) {
  39.     auto withErrorMessage = [](const std::string& message) -> std::vector<Image> {
  40.         std::cerr << message << '\n';
  41.         return {};
  42.     };
  43.  
  44.     std::ifstream labelFile(labelFileName, std::ios::binary);
  45.     std::ifstream imageFile(imageFileName, std::ios::binary);
  46.  
  47.     if (!labelFile)
  48.         return withErrorMessage("Couldn't open label file");
  49.     if (!imageFile)
  50.         return withErrorMessage("Couldn't open image file");
  51.  
  52.     auto labelMagicNumber = read32BigEndianInt(labelFile);
  53.     auto imageMagicNumber = read32BigEndianInt(imageFile);
  54.     if (labelMagicNumber != 2049)
  55.         return withErrorMessage("Magic number for label file is incorrect");
  56.     if (imageMagicNumber != 2051)
  57.         return withErrorMessage("Magic number for image file is incorrect");
  58.  
  59.     auto labelItemCount = read32BigEndianInt(labelFile);
  60.     auto imageItemCount = read32BigEndianInt(imageFile);
  61.     if (labelItemCount != imageItemCount)
  62.         return withErrorMessage("Input size of label and image files are different");
  63.     auto itemCount = labelItemCount;
  64.  
  65.     auto rowCount = read32BigEndianInt(imageFile);
  66.     auto colCount = read32BigEndianInt(imageFile);
  67.     if (rowCount != 28 || colCount != 28)
  68.         return withErrorMessage("image size is not 28x28 as expected");
  69.    
  70.     std::vector<Image> dataSet;
  71.     for (int i = 0; i < itemCount; ++i) {
  72.         dataSet.emplace_back(readBytes<1>(labelFile)[0], readBytes<28*28>(imageFile));
  73.     }
  74.  
  75.     return dataSet;
  76. }
  77.  
  78. void showDataset(const std::vector<Image>& dataSet) {
  79.     for (auto& digit : dataSet) {
  80.         std::cout << digit << '\n';
  81.     }
  82. }
  83. std::vector<ImageObject> getImages() {
  84.     auto dataSet = readDataSet("train-labels.idx1-ubyte", "train-images.idx3-ubyte");
  85.     //showDataset(dataSet); // uncomment to see dataset
  86.  
  87.     std::vector<ImageObject> images;
  88.     std::vector<std::vector<int>> image2d(28);
  89.     for (auto& row : image2d) {
  90.         row.resize(28);
  91.     }
  92.     for (auto& image : dataSet) {
  93.         for (int i = 0; i < 28; ++i) {
  94.             for (int j = 0; j < 28; ++j) {
  95.                 image2d[i][j] = image.at(i, j);
  96.             }
  97.         }
  98.         images.emplace_back(image.label, image2d);
  99.     }
  100.  
  101.     return images;
  102. }
  103.  
  104. int main() {
  105.     auto images = getImages();
  106.     std::vector<ImageObject> trainImages(images.begin(), images.begin() + images.size() * 0.8);
  107.     std::vector<ImageObject> testImages(images.begin() + images.size() * 0.8, images.end());
  108.  
  109.     // here train using trainImages and test using testImages
  110. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement