/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */

/*
 * File:   main.cpp
 * Author: ozair
 *
 * Created on March 29, 2016, 1:33 PM
 */
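
// Trains a feed-forward neural network (mlpack 2.x ANN API) on "snpbinary.csv"
// and reports accuracy, precision, and recall for the training and test splits
// over a sweep of confidence thresholds.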
#include <iostream>

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
#include <mlpack/methods/ann/performance_functions/mse_function.hpp>
#include <mlpack/core/optimizers/sgd/sgd.hpp>
#include <mlpack/methods/ann/init_rules/random_init.hpp>
#include <mlpack/methods/ann/layer/bias_layer.hpp>
#include <mlpack/methods/ann/layer/linear_layer.hpp>
#include <mlpack/methods/ann/layer/base_layer.hpp>
#include <mlpack/methods/ann/layer/multiclass_classification_layer.hpp>
#include <mlpack/methods/ann/layer/dropout_layer.hpp>
#include <mlpack/methods/ann/layer/dropconnect_layer.hpp>

using namespace std;
using namespace mlpack;
using namespace mlpack::ann;
using namespace mlpack::optimization;
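
// Evaluation metrics computed at a single confidence threshold. (Despite the
// name, this holds summary metrics rather than raw confusion-matrix counts.)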
struct ConfusionMatrix {
    double accuracy;
    double precision;
    double recall;
};
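
// Print the dimensions of an Armadillo matrix for quick sanity checks.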
void printMatAttr(char const* MatName, const arma::mat& data) {
    cout << "Attributes of " << MatName << '\n';
    cout << "# of rows: " << data.n_rows << '\n';
    cout << "# of cols: " << data.n_cols << '\n';
    cout << "# of elements: " << data.n_elem << "\n\n";
}
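
// Threshold the network's raw outputs at confidenceThreshold and compare the
// resulting class assignments against the labels. prediction and labels are
// expected to be 1 x N row vectors (one column per sample). Precision is
// undefined (division by zero) at thresholds where no sample is predicted as
// the positive class.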
ConfusionMatrix getClassificationError(const arma::mat& prediction,
                                       const arma::mat& labels,
                                       int positiveClass,
                                       double confidenceThreshold) {
    ConfusionMatrix attr;
    size_t error = 0, tp = 0;
    arma::mat predictionClass(prediction.n_rows, prediction.n_cols);

    for (size_t i = 0; i < labels.n_cols; i++) {
        // Hard-assign a class label from the network's confidence.
        predictionClass(0, i) = prediction(0, i) > confidenceThreshold ? 1 : 0;

        if (arma::sum(arma::abs(predictionClass.col(i) - labels.col(i))) != 0)
            error++;  // misclassified sample
        else if (arma::sum(predictionClass.col(i) == positiveClass) == 1 &&
                 arma::sum(labels.col(i) == positiveClass) == 1)
            tp++;     // correctly classified positive sample
    }

    attr.accuracy  = ((labels.n_cols - error) / (double) labels.n_cols) * 100;
    attr.precision = tp / (double) arma::sum(predictionClass.row(0) == positiveClass);
    attr.recall    = tp / (double) arma::sum(labels.row(0) == positiveClass);
    return attr;
}
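
// Load the dataset, split it 75/25 into training and test sets, train the
// network with SGD, and report metrics on both splits across a range of
// confidence thresholds.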
int main(int argc, char** argv) {
    arma::mat trainingData, trainingLabels, testData, testLabels,
              testprediction, trainprediction;
    size_t HiddenLayerSize, InputLayerSize, NumOfSamples, NumOfTrainingSamples,
           NumOfTestSamples, NumOfFeatures;
    double trainingRatio = 0.75;

    // Each row of snpbinary.csv is one sample; the last column is the label.
    trainingData.load("snpbinary.csv");
    NumOfSamples = trainingData.n_rows;
    NumOfFeatures = trainingData.n_cols - 1;

    // Shuffle the samples (rows) before splitting into training/test sets.
    arma::arma_rng::set_seed_random();
    trainingData = arma::shuffle(trainingData);
    cout << "Number of samples: " << NumOfSamples << "\n\n";

    // Split off the label column.
    trainingLabels = trainingData.col(NumOfFeatures);
    trainingData.shed_col(NumOfFeatures);

    // The first NumOfTestSamples rows become the test set; the rest is training.
    NumOfTrainingSamples = NumOfSamples * trainingRatio;
    NumOfTestSamples = NumOfSamples - NumOfTrainingSamples;
    testData = trainingData.rows(0, NumOfTestSamples - 1);
    testLabels = trainingLabels.rows(0, NumOfTestSamples - 1);
    trainingData.shed_rows(0, NumOfTestSamples - 1);
    trainingLabels.shed_rows(0, NumOfTestSamples - 1);

    printMatAttr("TrainingData", trainingData);
    printMatAttr("TrainingLabels", trainingLabels);
    printMatAttr("TestData", testData);
    printMatAttr("TestLabels", testLabels);

    // mlpack expects column-major data: one column per sample.
    trainingLabels = trans(trainingLabels);
    trainingData = trans(trainingData);
    testData = trans(testData);
    testLabels = trans(testLabels);

    InputLayerSize = NumOfFeatures;
    HiddenLayerSize = NumOfFeatures + (NumOfFeatures / 2);
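
    // Network topology (active layers only): input features -> Linear -> Bias
    // -> Logistic -> Linear -> Bias -> Logistic -> Linear -> Bias -> Logistic
    // (single output unit). Three additional hidden blocks and a dropout /
    // dropconnect layer are declared below but left out of the wired-up model.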
    LinearLayer<> inputLayer(InputLayerSize, HiddenLayerSize);
    BiasLayer<> inputBiasLayer(HiddenLayerSize);
    BaseLayer<LogisticFunction> inputBaseLayer;

    LinearLayer<> hiddenLayer1(HiddenLayerSize, HiddenLayerSize);
    BiasLayer<> hiddenBiasLayer1(HiddenLayerSize);
    BaseLayer<LogisticFunction> hiddenBaseLayer1;

    LinearLayer<> hiddenLayer2(HiddenLayerSize, HiddenLayerSize);
    BiasLayer<> hiddenBiasLayer2(HiddenLayerSize);
    BaseLayer<LogisticFunction> hiddenBaseLayer2;

    LinearLayer<> hiddenLayer3(HiddenLayerSize, HiddenLayerSize);
    BiasLayer<> hiddenBiasLayer3(HiddenLayerSize);
    BaseLayer<LogisticFunction> hiddenBaseLayer3;

    LinearLayer<> hiddenLayer4(HiddenLayerSize, HiddenLayerSize);
    BiasLayer<> hiddenBiasLayer4(HiddenLayerSize);
    BaseLayer<LogisticFunction> hiddenBaseLayer4;

    LinearLayer<> outputLayer(HiddenLayerSize, 1);
    BiasLayer<> outputBiasLayer(1);
    //DropConnectLayer<decltype(outputLayer)> dropConnectLayer0(outputLayer);
    BaseLayer<LogisticFunction> outputBaseLayer;
    DropoutLayer<> dropoutLayer0;

    MulticlassClassificationLayer classOutputLayer;

    auto modules = std::tie(inputLayer, inputBiasLayer, inputBaseLayer,
                            //hiddenLayer1, hiddenBiasLayer1, hiddenBaseLayer1,
                            //hiddenLayer2, hiddenBiasLayer2, hiddenBaseLayer2,
                            //hiddenLayer3, hiddenBiasLayer3, hiddenBaseLayer3,
                            hiddenLayer4, hiddenBiasLayer4, hiddenBaseLayer4,
                            outputLayer, outputBiasLayer, outputBaseLayer);

    FFN<decltype(modules), decltype(classOutputLayer), RandomInitialization,
        MeanSquaredErrorFunction> NN(modules, classOutputLayer);

    // SGD: step size 0.01, at most 100000 iterations, tolerance 1e-5, shuffled visits.
    SGD<decltype(NN)> opt(NN, 0.01, 100000, 1e-5, true);

    cout << "Begin Training\n";
    NN.Train(trainingData, trainingLabels, opt);

    // Raw network outputs (logistic confidences in [0, 1]) for both splits.
    NN.Predict(trainingData, trainprediction);
    NN.Predict(testData, testprediction);
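
    // Sweep confidence thresholds from 0.05 to 0.95; each output line is
    // "threshold, accuracy, precision, recall" for the training split,
    // followed by the same four values for the test split.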
    arma::Col<double> confidenceList = arma::linspace<arma::Col<double>>(0.05, 0.95, 36);
    ConfusionMatrix attr;

    for (size_t t = 0; t < confidenceList.n_elem; t++) {
        cout << confidenceList(t) << ", ";
        attr = getClassificationError(trainprediction, trainingLabels, 1, confidenceList(t));
        cout << attr.accuracy << ", ";
        cout << attr.precision << ", ";
        cout << attr.recall << " ... ";

        cout << confidenceList(t) << ", ";
        attr = getClassificationError(testprediction, testLabels, 1, confidenceList(t));
        cout << attr.accuracy << ", ";
        cout << attr.precision << ", ";
        cout << attr.recall << "\n";
    }

    return 0;
}
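
// Build sketch (an assumption, not part of the original paste): with mlpack 2.x
// and Armadillo installed, something along the lines of
//   g++ -std=c++11 main.cpp -o snp_ffn -lmlpack -larmadillo
// should work; exact library names and extra Boost dependencies vary by
// installation.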