Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#include <array>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>
/// One labelled sample: a 28x28 bitmap of one-byte pixels plus its label byte.
///
/// Pixels are stored row-major in `data` (index = col + row * 28), in the same
/// order the bytes are read from the image file.
struct Image {
    std::array<uint8_t, 28 * 28> data;
    uint8_t label;

    /// @param label  label byte read from the label file.
    /// @param data   28*28 pixel bytes in row-major order.
    // Initializers are listed in declaration order (data, then label) so the
    // list matches the order in which members are actually initialized.
    Image(uint8_t label, const std::array<uint8_t, 28 * 28>& data) : data(data), label(label) {}

    /// Returns the pixel at (row, col). Both indices must be in [0, 28); this is not checked.
    uint8_t at(int row, int col) const {
        return data[col + row * 28];
    }
};
- std::ostream& operator<<(std::ostream& out, const Image& image) {
- out << "label = " << static_cast<int>(image.label) << '\n';
- for (int row = 0; row < 28; ++row) {
- for (int col = 0; col < 28; ++col) {
- out << ((image.at(row, col) >= 128) ? '#' : ' ');
- }
- out << '\n';
- }
- return out;
- }
/// Reads `Size` raw bytes from `in` into a fixed-size array.
///
/// The buffer is value-initialized, so if the stream ends early (or is already
/// in a failed state) the unread tail is zero-filled rather than left holding
/// indeterminate bytes. Callers detect a short read through the stream state
/// (`!in`, or `in.gcount() < Size`).
template<std::size_t Size> std::array<uint8_t, Size> readBytes(std::istream& in) {
    std::array<uint8_t, Size> buf{};
    // buf.data() (unlike &buf[0]) is valid even when Size == 0.
    in.read(reinterpret_cast<char*>(buf.data()), static_cast<std::streamsize>(Size));
    return buf;
}
- uint32_t read32BigEndianInt(std::istream& in) {
- auto buf = readBytes<4>(in);
- return (buf[0] << 24) | (buf[1] << 16) | (buf[2] << 8) | buf[3];
- }
- std::vector<Image> readDataSet(const std::string& labelFileName, const std::string& imageFileName) {
- auto withErrorMessage = [](const std::string& message) -> std::vector<Image> {
- std::cerr << message << '\n';
- return {};
- };
- std::ifstream labelFile(labelFileName, std::ios::binary);
- std::ifstream imageFile(imageFileName, std::ios::binary);
- if (!labelFile)
- return withErrorMessage("Couldn't open label file");
- if (!imageFile)
- return withErrorMessage("Couldn't open image file");
- auto labelMagicNumber = read32BigEndianInt(labelFile);
- auto imageMagicNumber = read32BigEndianInt(imageFile);
- if (labelMagicNumber != 2049)
- return withErrorMessage("Magic number for label file is incorrect");
- if (imageMagicNumber != 2051)
- return withErrorMessage("Magic number for image file is incorrect");
- auto labelItemCount = read32BigEndianInt(labelFile);
- auto imageItemCount = read32BigEndianInt(imageFile);
- if (labelItemCount != imageItemCount)
- return withErrorMessage("Input size of label and image files are different");
- auto itemCount = labelItemCount;
- auto rowCount = read32BigEndianInt(imageFile);
- auto colCount = read32BigEndianInt(imageFile);
- if (rowCount != 28 || colCount != 28)
- return withErrorMessage("image size is not 28x28 as expected");
- std::vector<Image> dataSet;
- for (int i = 0; i < itemCount; ++i) {
- dataSet.emplace_back(readBytes<1>(labelFile)[0], readBytes<28*28>(imageFile));
- }
- return dataSet;
- }
- void showDataset(const std::vector<Image>& dataSet) {
- for (auto& digit : dataSet) {
- std::cout << digit << '\n';
- }
- }
- std::vector<ImageObject> getImages() {
- auto dataSet = readDataSet("train-labels.idx1-ubyte", "train-images.idx3-ubyte");
- //showDataset(dataSet); // uncomment to see dataset
- std::vector<ImageObject> images;
- std::vector<std::vector<int>> image2d(28);
- for (auto& row : image2d) {
- row.resize(28);
- }
- for (auto& image : dataSet) {
- for (int i = 0; i < 28; ++i) {
- for (int j = 0; j < 28; ++j) {
- image2d[i][j] = image.at(i, j);
- }
- }
- images.emplace_back(image.label, image2d);
- }
- return images;
- }
- int main() {
- auto images = getImages();
- std::vector<ImageObject> trainImages(images.begin(), images.begin() + images.size() * 0.8);
- std::vector<ImageObject> testImages(images.begin() + images.size() * 0.8, images.end());
- // here train using trainImages and test using testImages
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement