Advertisement
Not a member of Pastebin yet?
Sign up: it unlocks many cool features!
- /*
- $(CXX) -std=c++11 -O3 -march=native knn.cpp
- */
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <valarray>
#include <vector>

#include <boost/algorithm/string.hpp>
#include <boost/tokenizer.hpp>
namespace {

// One sample row: a class label plus the row's integer pixel values.
struct label_pixel {
    int label;
    std::valarray<int> pixels;
};

// Parses one comma-separated data row of the form "label,p0,p1,...".
// Returns the label (field 0) and the remaining fields as pixels.
// Throws std::runtime_error("bad input") if the row has fewer than two
// fields; std::stoi throws std::invalid_argument / std::out_of_range for a
// field that is not an integer (including an empty field such as "1,,2").
label_pixel parse_row(std::string const& line) {
    std::vector<int> integers;
    std::istringstream fields{line};
    std::string field;
    while (std::getline(fields, field, ',')) {
        integers.push_back(std::stoi(field));  // stoi skips leading spaces
    }
    if (integers.size() < 2) {
        throw std::runtime_error("bad input");
    }
    // Skip field 0: it is the label. Including it in the pixel vector would
    // feed the label into every distance computation.
    return {integers[0],
            std::valarray<int>(integers.data() + 1, integers.size() - 1)};
}

// Reads every data row of the CSV file `name`, discarding the first line
// (a header). Throws std::runtime_error if the file cannot be opened, and
// propagates parse_row's exceptions for malformed rows.
std::vector<label_pixel> slurp_file(std::string const& name) {
    std::ifstream file{name};
    if (!file) {
        throw std::runtime_error("cannot open " + name);
    }
    std::vector<label_pixel> result;
    std::string line;
    std::getline(file, line);  // skip first line
    while (std::getline(file, line)) {
        result.push_back(parse_row(line));
    }
    return result;
}

// Squared Euclidean distance between two pixel vectors.
// Throws std::runtime_error if the lengths differ (valarray arithmetic on
// mismatched sizes is undefined behaviour).
// NOTE(review): the sum is an int; this assumes pixel values are small
// (e.g. 0..255) so the result cannot overflow — confirm for other data.
int distance(std::valarray<int> const& x, std::valarray<int> const& y) {
    if (x.size() != y.size()) {
        throw std::runtime_error("pixel count mismatch");
    }
    const std::valarray<int> diff = x - y;
    return (diff * diff).sum();
}

// Returns an iterator to the element of [first, last) with the smallest
// m(element), evaluating m exactly once per element. On ties the earliest
// element wins. Returns `last` for an empty range.
// (std::min_element would need a comparator that evaluates m twice per
// comparison, hence this helper.)
template<typename FwdIt, typename Metric>
inline FwdIt min_by(FwdIt first, const FwdIt last, Metric m) {
    if( first == last )
        return last;
    FwdIt min = first++;
    auto dist = m(*min);
    while( first != last ) {
        const auto d = m(*first);
        if( d < dist ) {
            dist = d;
            min = first;
        }
        ++first;
    }
    return min;
}

// 1-nearest-neighbour classification: returns the label of the training
// sample whose pixels are closest to `input.pixels` (first one on ties).
// Throws std::runtime_error if `training` is empty, instead of
// dereferencing end().
int classify(std::vector<label_pixel> const& training, label_pixel const& input) {
    if (training.empty()) {
        throw std::runtime_error("empty training set");
    }
    return min_by(training.begin(), training.end(), [&](label_pixel const& sample) {
        return distance(sample.pixels, input.pixels);
    })->label;
}

} // namespace
- int main() try {
- const auto training = slurp_file("trainingsample.csv");
- const auto validation = slurp_file("validationsample.csv");
- const auto num_correct = count_if(begin(validation), end(validation),
- [&](label_pixel const& sample) {
- return classify(training, sample) == sample.label;
- });
- std::cout << "Percentage correct: "
- << (100 * double(num_correct) / validation.size())
- << std::endl;
- return 0;
- } catch(std::exception const& e) {
- std::cout << "Exception:" << e.what() << std::endl;
- }
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement