Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <mlpack/core.hpp>
- #include <mlpack/methods/logistic_regression/logistic_regression.hpp>
- #include <mlpack/core/optimizers/sgd/test_function.hpp>
- #include <mlpack/core/optimizers/sgd/update_policies/vanilla_update.hpp>
- #include <mlpack/methods/ann/layer/layer.hpp>
- #include <mlpack/methods/ann/ffn.hpp>
- #include "cmaes.hpp"
- using namespace std;
- using namespace arma;
- using namespace mlpack;
- using namespace mlpack::ann;
- using namespace mlpack::optimization;
- using namespace mlpack::optimization::test;
- using namespace mlpack::distribution;
- using namespace mlpack::regression;
- /**
- * Train and evaluate a vanilla network with the specified structure.
- */
- template<typename MatType = arma::mat>
- void BuildVanillaNetwork(MatType& trainData,
- MatType& trainLabels,
- MatType& testData,
- MatType& testLabels,
- const size_t outputSize,
- const size_t hiddenLayerSize,
- const size_t maxEpochs,
- const double classificationErrorThreshold)
- {
- FFN<NegativeLogLikelihood<> > model;
- model.Add<Linear<> >(trainData.n_rows, hiddenLayerSize);
- model.Add<SigmoidLayer<> >();
- model.Add<Linear<> >(hiddenLayerSize, outputSize);
- model.Add<LogSoftMax<> >();
- int dim = trainData.n_rows * hiddenLayerSize * outputSize;
- arma::mat start1(dim, 1); start1.randu();
- arma::mat initialStdDeviations1(dim, 1); initialStdDeviations1.randu();
- CMAES opt(dim, start1, initialStdDeviations1, 50000, 1e-8);
- model.Train(trainData, trainLabels, opt);
- MatType predictionTemp;
- model.Predict(testData, predictionTemp);
- MatType prediction = arma::zeros<MatType>(1, predictionTemp.n_cols);
- for (size_t i = 0; i < predictionTemp.n_cols; ++i)
- {
- prediction(i) = arma::as_scalar(arma::find(
- arma::max(predictionTemp.col(i)) == predictionTemp.col(i), 1)) + 1;
- }
- size_t error = 0;
- for (size_t i = 0; i < testData.n_cols; i++)
- {
- if (int(arma::as_scalar(prediction.col(i))) ==
- int(arma::as_scalar(testLabels.col(i))))
- {
- error++;
- }
- }
- double classificationError = 1 - double(error) / testData.n_cols;
- cout << "require " << classificationError << " <= " << classificationErrorThreshold << endl;
- }
- int main()
- {
- mlpack::math::RandomSeed(std::time(NULL));
- // SGD TEST CASE PASS
- SGDTestFunction test;
- size_t N = test.NumFunctions();
- arma::mat start(N,1); start.fill(0.5);
- arma::mat initialStdDeviations(N,1); initialStdDeviations.fill(1.5);
- CMAES s(N,start,initialStdDeviations,10000,1e-18);
- arma::mat coordinates(N,1);
- double result = s.Optimize(test, coordinates);
- cout <<
- "BOOST_REQUIRE_CLOSE(result, -1.0, 0.05); \n" <<
- "BOOST_REQUIRE_SMALL(coordinates[0], 1e-3); \n" <<
- "BOOST_REQUIRE_SMALL(coordinates[1], 1e-7); \n" <<
- "BOOST_REQUIRE_SMALL(coordinates[2], 1e-7);" << endl;
- cout << endl << result << endl;
- cout << coordinates[0] << endl;
- cout << coordinates[1] << endl;
- cout << coordinates[2] << endl;
- // Generate a two-Gaussian dataset.
- GaussianDistribution g1(arma::vec("1.0 1.0 1.0"), arma::eye<arma::mat>(3, 3));
- GaussianDistribution g2(arma::vec("9.0 9.0 9.0"), arma::eye<arma::mat>(3, 3));
- arma::mat data(3, 1000);
- arma::Row<size_t> responses(1000);
- for (size_t i = 0; i < 500; ++i)
- {
- data.col(i) = g1.Random();
- responses[i] = 0;
- }
- for (size_t i = 500; i < 1000; ++i)
- {
- data.col(i) = g2.Random();
- responses[i] = 1;
- }
- // Shuffle the dataset.
- arma::uvec indices = arma::shuffle(arma::linspace<arma::uvec>(0,
- data.n_cols - 1, data.n_cols));
- arma::mat shuffledData(3, 1000);
- arma::Row<size_t> shuffledResponses(1000);
- for (size_t i = 0; i < data.n_cols; ++i)
- {
- shuffledData.col(i) = data.col(indices[i]);
- shuffledResponses[i] = responses[indices[i]];
- }
- // Create a test set.
- arma::mat testData(3, 1000);
- arma::Row<size_t> testResponses(1000);
- for (size_t i = 0; i < 500; ++i)
- {
- testData.col(i) = g1.Random();
- testResponses[i] = 0;
- }
- for (size_t i = 500; i < 1000; ++i)
- {
- testData.col(i) = g2.Random();
- testResponses[i] = 1;
- }
- int dim = shuffledData.n_rows + 1;
- arma::mat start1(dim, 1); start1.fill(0.5);
- arma::mat initialStdDeviations1(dim, 1); initialStdDeviations1.fill(1.5);
- CMAES test1(dim, start1, initialStdDeviations1, 50000, 1e-8);
- LogisticRegression<arma::mat> lr(shuffledData, shuffledResponses, test1, 0.5);
- // Ensure that the error is close to zero.
- const double acc = lr.ComputeAccuracy(data, responses);
- cout << "got this value = " << acc << " should be = 100.0 with tolerance = 0.3" << endl; // 0.3% error tolerance.
- const double testAcc = lr.ComputeAccuracy(testData, testResponses);
- cout << "got this value = " << testAcc << " should be = 100.0 with tolerance = 0.3" << endl;
- // Load the dataset.
- arma::mat dataset;
- data::Load("thyroid_train.csv", dataset, true);
- arma::mat trainData = dataset.submat(0, 0, dataset.n_rows - 4,
- dataset.n_cols - 1);
- arma::mat trainLabelsTemp = dataset.submat(dataset.n_rows - 3, 0,
- dataset.n_rows - 1, dataset.n_cols - 1);
- arma::mat trainLabels = arma::zeros<arma::mat>(1, trainLabelsTemp.n_cols);
- for (size_t i = 0; i < trainLabelsTemp.n_cols; ++i)
- {
- trainLabels(i) = arma::as_scalar(arma::find(
- arma::max(trainLabelsTemp.col(i)) == trainLabelsTemp.col(i), 1)) + 1;
- }
- data::Load("thyroid_test.csv", dataset, true);
- arma::mat testData1 = dataset.submat(0, 0, dataset.n_rows - 4,
- dataset.n_cols - 1);
- arma::mat testLabelsTemp = dataset.submat(dataset.n_rows - 3, 0,
- dataset.n_rows - 1, dataset.n_cols - 1);
- arma::mat testLabels = arma::zeros<arma::mat>(1, testLabelsTemp.n_cols);
- for (size_t i = 0; i < testLabels.n_cols; ++i)
- {
- testLabels(i) = arma::as_scalar(arma::find(
- arma::max(testLabelsTemp.col(i)) == testLabelsTemp.col(i), 1)) + 1;
- }
- // Vanilla neural net with logistic activation function.
- // Because 92 percent of the patients are not hyperthyroid the neural
- // network must be significant better than 92%.
- BuildVanillaNetwork<>
- (trainData, trainLabels, testData1, testLabels, 3, 8, 70, 0.1);
- return 0;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement