Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- * Main neural-network code.
- */
- public class NNVezba implements LearningEventListener{
- int inputCount = 13;
- int outputCount = 3;
- double[] lr = {0.2, 0.4, 0.6};
- ArrayList<Training> trainings = new ArrayList<>();
- /**
- * @param args the command line arguments
- */
- public static void main(String[] args) {
- // TODO code application logic here
- (new NNVezba()).run();
- }
- @Override
- public void handleLearningEvent(LearningEvent event) {
- BackPropagation bp = (BackPropagation) event.getSource();
- System.out.println("iteration: " +bp.getCurrentIteration()+" "
- + "Total network error: " +bp.getTotalNetworkError());
- }
- private void run() {
- String filePath = "wines.csv";
- DataSet dataSet = DataSet.createFromFile(filePath, inputCount, outputCount, ",");
- Normalizer norm = new MaxNormalizer(dataSet);
- norm.normalize(dataSet);
- dataSet.shuffle();
- DataSet[] trainTest = dataSet.split(0.7, 0.3);
- DataSet trainSet = trainTest[0];
- DataSet testSet = trainTest[1];
- int numOfIterations = 0;
- int numOfTrainings = 0;
- for (double l : lr) {
- MultiLayerPerceptron nn = new MultiLayerPerceptron(inputCount, 22, outputCount);
- BackPropagation learningRule = nn.getLearningRule();
- learningRule.addListener(this);
- learningRule.setMaxError(0.02);
- learningRule.setLearningRate(l);
- learningRule.setMaxIterations(1000);
- nn.learn(trainSet);
- numOfTrainings++;
- numOfIterations += learningRule.getCurrentIteration();
- double accuracy = evaluateAccuracy(nn, testSet);
- Training t = new Training(nn, accuracy);
- trainings.add(t);
- }
- System.out.println("Srednja vrednosti broja iteracija :" +numOfTrainings/numOfIterations);
- SaveNetWithMaxAccuracy();
- }
- private double evaluateAccuracy(MultiLayerPerceptron nn, DataSet testSet) {
- ConfMatrix cmatrix = new ConfMatrix(outputCount);
- double accuracy = 0;
- for (DataSetRow dataSetRow : testSet) {
- nn.setInput(dataSetRow.getInput());
- nn.calculate();
- int actual = getMaxIndex(dataSetRow.getDesiredOutput());
- int predicted = getMaxIndex(nn.getOutput());
- cmatrix.incrementElement(actual, predicted);
- }
- for (int i = 0; i < outputCount; i++) {
- accuracy += (double) (cmatrix.getTruePositive(i) + cmatrix.getTrueNegative(i)) / cmatrix.total;
- }
- cmatrix.print();
- System.out.println("accuracy : "+(double) accuracy/outputCount);
- return (double) accuracy / outputCount;
- }
- private void SaveNetWithMaxAccuracy() {
- Training maxTr = trainings.get(0);
- for (Training t : trainings) {
- if(t.getAccuracy() > maxTr.getAccuracy()){
- maxTr = t;
- }
- }
- maxTr.getNn().save("nn.nnet");
- }
- private int getMaxIndex(double[] output) {
- int maxIndex = 0;
- for (int i = 0; i < output.length; i++) {
- if(output[maxIndex] < output[i]){
- maxIndex = i;
- }
- }
- return maxIndex;
- }
- }
- ________________________________________________________________________________________________________________
- /**
- *
- * Training class: pairs a trained neural network with its measured accuracy.
- */
- package nn.vezba;
- import org.neuroph.core.NeuralNetwork;
- public class Training {
- private NeuralNetwork nn;
- private double accuracy;
- public Training(NeuralNetwork nn, double accuracy) {
- this.nn = nn;
- this.accuracy = accuracy;
- }
- public double getAccuracy() {
- return accuracy;
- }
- public void setAccuracy(double accuracy) {
- this.accuracy = accuracy;
- }
- public NeuralNetwork getNn() {
- return nn;
- }
- public void setNn(NeuralNetwork nn) {
- this.nn = nn;
- }
- }
- ________________________________________________________________________________________________________________
/**
 * Confusion matrix for a fixed number of classes (largely modeled on Neuroph's
 * ConfusionMatrix class). Rows index the actual class and columns the predicted class.
 */
public class ConfMatrix {

    /** Sample counts, indexed as {@code matrix[actual][predicted]}. */
    int[][] matrix;
    /** Number of classes; the matrix is {@code classCount x classCount}. */
    int classCount;
    /** Number of samples recorded through {@link #incrementElement(int, int)}. */
    int total = 0;

    /**
     * Creates an all-zero matrix.
     *
     * @param classCount number of classes
     */
    public ConfMatrix(int classCount) {
        this.matrix = new int[classCount][classCount];
        this.classCount = classCount;
    }

    /**
     * Records one sample.
     *
     * @param actual    the true class index (row)
     * @param predicted the predicted class index (column)
     */
    public void incrementElement(int actual, int predicted) {
        matrix[actual][predicted] += 1;
        total += 1;
    }

    /** Returns the diagonal entry for class {@code cl}: samples of {@code cl} predicted as {@code cl}. */
    public int getTruePositive(int cl) {
        return matrix[cl][cl];
    }

    /** Returns the sum of all entries whose row and column both differ from {@code cl}. */
    public int getTrueNegative(int cl) {
        int count = 0;
        for (int row = 0; row < classCount; row++) {
            for (int col = 0; col < classCount; col++) {
                if (row != cl && col != cl) {
                    count += matrix[row][col];
                }
            }
        }
        return count;
    }

    /** Returns the sum of column {@code cl} excluding the diagonal: other classes predicted as {@code cl}. */
    public int getFalsePositive(int cl) {
        int count = 0;
        for (int row = 0; row < classCount; row++) {
            if (row != cl) {
                count += matrix[row][cl];
            }
        }
        return count;
    }

    /** Returns the sum of row {@code cl} excluding the diagonal: samples of {@code cl} predicted as another class. */
    public int getFalseNegative(int cl) {
        int count = 0;
        for (int col = 0; col < classCount; col++) {
            if (col != cl) {
                count += matrix[cl][col];
            }
        }
        return count;
    }

    /** Prints the matrix to standard output, one row per line, each value followed by a space. */
    public void print() {
        for (int[] row : matrix) {
            StringBuilder line = new StringBuilder();
            for (int col = 0; col < matrix.length; col++) {
                line.append(row[col]).append(' ');
            }
            System.out.println(line);
        }
    }
}
Advertisement
Add Comment
Please, Sign In to add comment