Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <mlpack/core.hpp>
- #include <mlpack/methods/cf/cf.hpp>
- #include <mlpack/methods/regularized_svd/regularized_svd.hpp>
- using namespace mlpack;
- using namespace mlpack::cf;
- using namespace mlpack::svd;
- using namespace std;
- PARAM_STRING_REQ("input_file", "Input ratings.", "i");
- PARAM_STRING_REQ("test_file", "Test ratings.", "t");
- PARAM_STRING_REQ("algorithm", "Algorithm to use: 'regsvd', 'nmf'.", "a");
- PARAM_INT("rank", "Rank of decomposed matrices.", "R", 2);
- PARAM_INT("neighborhood", "Number of users in neighborhood.", "n", 5);
- PARAM_INT("iterations", "Maximum number of iterations.", "m", 100);
- PARAM_DOUBLE("lambda", "Regularization for regularized SVD.", "l", 0.01);
- PARAM_DOUBLE("alpha", "Learning rate for regularized SVD.", "a", 0.01);
- int main(int argc, char** argv)
- {
- CLI::ParseCommandLine(argc, argv);
- arma::mat dataset;
- const string inputFile = CLI::GetParam<string>("input_file");
- data::Load(inputFile, dataset, true);
- arma::mat testSet;
- const string testFile = CLI::GetParam<string>("test_file");
- data::Load(testFile, testSet, true);
- const string algorithm = CLI::GetParam<string>("algorithm");
- const size_t rank = (size_t) CLI::GetParam<int>("rank");
- const size_t neighborhood = (size_t) CLI::GetParam<int>("neighborhood");
- if (algorithm == "regsvd")
- {
- const size_t iterations = (size_t) CLI::GetParam<int>("iterations");
- const double alpha = CLI::GetParam<double>("alpha");
- const double lambda = CLI::GetParam<double>("lambda");
- RegularizedSVD<> rsvd(iterations, alpha, lambda);
- Timer::Start("cf_decomposition");
- CF<RegularizedSVD<>> cf(dataset, rsvd, neighborhood, rank);
- Timer::Stop("cf_decomposition");
- Log::Info << "Decomposition complete.\n";
- // Compute MSE of predictions.
- double error = 0.0;
- Log::Info.ignoreInput = true;
- for (size_t i = 0; i < testSet.n_cols; ++i)
- {
- if (i % 1000 == 0)
- Log::Warn << "On prediction " << i << " of " << testSet.n_cols << ".\n";
- const double prediction = cf.Predict(testSet(0, i), testSet(1, i));
- error += std::pow(prediction - testSet(2, i), 2.0);
- }
- error = sqrt(error) / testSet.n_cols;
- Log::Info.ignoreInput = false;
- Log::Info << "Mean squared error: " << error << "." << endl;
- }
- else if (algorithm == "nmf")
- {
- Log::Fatal << "NMF not implemented yet.\n";
- }
- else
- {
- Log::Fatal << "Unknown algorithm '" << algorithm << "'!" << endl;
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement