Advertisement
Guest User

C++11 k-NN

a guest
Jun 10th, 2014
162
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 2.19 KB | None | 0 0
  1. /*
  2.     $(CXX) -std=c++11 -O3 -march=native knn.cpp
  3. */
  #include <cstdint>
  #include <vector>
  #include <valarray>
  #include <fstream>
  #include <iostream>
  #include <numeric>
  #include <algorithm>
  #include <stdexcept>
  #include <string>
  #include <boost/tokenizer.hpp>
  #include <boost/algorithm/string.hpp>
  13.  
  14. namespace {
  15.  
  // One labelled sample: an integer class label together with the sample's
  // integer feature values. Kept as a plain aggregate so that it can be
  // brace-initialised as {label, pixels}.
  struct label_pixel {
      // Class label of the sample.
      int label;

      // Feature values of the sample.
      std::valarray<int> pixels;
  };
  20.  
  21. std::vector<label_pixel> slurp_file(std::string const& name) {
  22.     std::ifstream file{name};
  23.     std::vector<label_pixel> result;
  24.     std::string line;
  25.  
  26.     getline(file, line); // skip first line
  27.     while(getline(file, line)) {
  28.         using boost::algorithm::trim;
  29.         using boost::tokenizer;
  30.  
  31.         trim(line);
  32.         std::vector<int> integers;
  33.         for(auto const& elt : tokenizer<>{line}) {
  34.             integers.push_back(stoi(elt));
  35.         }
  36.  
  37.         if( integers.size() < 2 ) {
  38.             throw std::runtime_error("bad input");
  39.         }
  40.  
  41.         result.push_back({integers[0], {integers.data(), integers.size() - 1}});
  42.     }
  43.     return result;
  44. }
  45.  
  // Squared Euclidean distance between two feature vectors: the sum over all
  // positions of (x[i] - y[i])^2. Both arguments are expected to have the same
  // length, as required by std::valarray's element-wise operators.
  int distance(std::valarray<int> const& x, std::valarray<int> const& y) {
      const std::valarray<int> delta = x - y;
      return (delta * delta).sum();
  }
  49.  
  // Returns an iterator to the element of [first, last) for which the metric
  // `m` yields the smallest value; the standard library has no direct
  // equivalent. When several elements share the smallest value the earliest
  // one wins. An empty range yields `last`. The metric is invoked exactly once
  // per element.
  template<typename FwdIt, typename Metric>
  inline FwdIt min_by(FwdIt first, const FwdIt last, Metric m) {
      if( first == last ) {
          return last;
      }

      FwdIt best       = first;
      auto  best_score = m(*best);

      for( ++first; first != last; ++first ) {
          const auto score = m(*first);
          if( score < best_score ) {
              best_score = score;
              best       = first;
          }
      }
      return best;
  }
  67.  
  68. int classify(std::vector<label_pixel> const& training, label_pixel const& input) {
  69.     return min_by(begin(training), end(training), [&](label_pixel const& sample) {
  70.         return distance(sample.pixels, input.pixels);
  71.     })->label;
  72. }
  73.  
  74. } // namespace
  75.  
  76. int main() try {
  77.     const auto training = slurp_file("trainingsample.csv");
  78.     const auto validation = slurp_file("validationsample.csv");
  79.  
  80.     const auto num_correct = count_if(begin(validation), end(validation),
  81.     [&](label_pixel const& sample) {
  82.         return classify(training, sample) == sample.label;
  83.     });
  84.  
  85.     std::cout << "Percentage correct: "
  86.               << (100 * double(num_correct) / validation.size())
  87.               << std::endl;
  88.     return 0;
  89. } catch(std::exception const& e) {
  90.     std::cout << "Exception:" << e.what() << std::endl;
  91. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement