Advertisement
Guest User

Untitled

a guest
Apr 27th, 2015
198
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 2.47 KB | None | 0 0
  1. #include <mlpack/core.hpp>
  2. #include <mlpack/methods/cf/cf.hpp>
  3. #include <mlpack/methods/regularized_svd/regularized_svd.hpp>
  4.  
  5. using namespace mlpack;
  6. using namespace mlpack::cf;
  7. using namespace mlpack::svd;
  8. using namespace std;
  9.  
  10. PARAM_STRING_REQ("input_file", "Input ratings.", "i");
  11. PARAM_STRING_REQ("test_file", "Test ratings.", "t");
  12. PARAM_STRING_REQ("algorithm", "Algorithm to use: 'regsvd', 'nmf'.", "a");
  13.  
  14. PARAM_INT("rank", "Rank of decomposed matrices.", "R", 2);
  15. PARAM_INT("neighborhood", "Number of users in neighborhood.", "n", 5);
  16.  
  17. PARAM_INT("iterations", "Maximum number of iterations.", "m", 100);
  18. PARAM_DOUBLE("lambda", "Regularization for regularized SVD.", "l", 0.01);
  19. PARAM_DOUBLE("alpha", "Learning rate for regularized SVD.", "a", 0.01);
  20.  
  21. int main(int argc, char** argv)
  22. {
  23.   CLI::ParseCommandLine(argc, argv);
  24.  
  25.   arma::mat dataset;
  26.   const string inputFile = CLI::GetParam<string>("input_file");
  27.   data::Load(inputFile, dataset, true);
  28.  
  29.   arma::mat testSet;
  30.   const string testFile = CLI::GetParam<string>("test_file");
  31.   data::Load(testFile, testSet, true);
  32.  
  33.   const string algorithm = CLI::GetParam<string>("algorithm");
  34.   const size_t rank = (size_t) CLI::GetParam<int>("rank");
  35.   const size_t neighborhood = (size_t) CLI::GetParam<int>("neighborhood");
  36.  
  37.   if (algorithm == "regsvd")
  38.   {
  39.     const size_t iterations = (size_t) CLI::GetParam<int>("iterations");
  40.     const double alpha = CLI::GetParam<double>("alpha");
  41.     const double lambda = CLI::GetParam<double>("lambda");
  42.  
  43.     RegularizedSVD<> rsvd(iterations, alpha, lambda);
  44.  
  45.     Timer::Start("cf_decomposition");
  46.     CF<RegularizedSVD<>> cf(dataset, rsvd, neighborhood, rank);
  47.     Timer::Stop("cf_decomposition");
  48.     Log::Info << "Decomposition complete.\n";
  49.  
  50.     // Compute MSE of predictions.
  51.     double error = 0.0;
  52.     Log::Info.ignoreInput = true;
  53.     for (size_t i = 0; i < testSet.n_cols; ++i)
  54.     {
  55.       if (i % 1000 == 0)
  56.         Log::Warn << "On prediction " << i << " of " << testSet.n_cols << ".\n";
  57.       const double prediction = cf.Predict(testSet(0, i), testSet(1, i));
  58.       error += std::pow(prediction - testSet(2, i), 2.0);
  59.     }
  60.     error = sqrt(error) / testSet.n_cols;
  61.     Log::Info.ignoreInput = false;
  62.  
  63.     Log::Info << "Mean squared error: " << error << "." << endl;
  64.   }
  65.   else if (algorithm == "nmf")
  66.   {
  67.     Log::Fatal << "NMF not implemented yet.\n";
  68.   }
  69.   else
  70.   {
  71.     Log::Fatal << "Unknown algorithm '" << algorithm << "'!" << endl;
  72.   }
  73. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement