AleksaLjujic

IS NN Kod novi

Aug 25th, 2025
105
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 18.59 KB | Software | 0 0
  1. //Breast cancer
  2.  
  3. public class BrestCancer implements LearningEventListener,NeurophExam{
  4.  
  5.     int inputCount = 30;
  6.     int outputCount = 1;
  7.     DataSet trainData;
  8.     DataSet testData;
  9.     double[] learningRates = {0.2,0.4,0.6};
  10.     int[] hiddenNeurons = {10,20,30};
  11.     int hiddenNeuron;
  12.     double learningRate;
  13.     int trainingCount = 0;
  14.     int totalIterations = 0;
  15.     double momentumValue = 0.7;
  16.     double testSplit = 0.35;
  17.     double trainSplit = 0.65;
  18.    
  19.     ArrayList<Training> trainings = new ArrayList<>();
  20.    
  21.     public static void main(String[] args) {
  22.         new BrestCancer().run();
  23.     }
  24.    
  25.     public void run(){
  26.         DataSet ds = loadDataSet();
  27.         ds = preprocessDataSet(ds);
  28.         DataSet[] trainTest = trainTestSplit(ds);
  29.         trainData = trainTest[0];
  30.         testData = trainTest[1];
  31.        
  32.         for (double lr : learningRates) {
  33.             learningRate = lr;
  34.             for (int hn : hiddenNeurons) {
  35.                 hiddenNeuron = hn;
  36.                 System.out.println("Hidden neurons::"+hiddenNeuron+"::Learning rate::"+learningRate);
  37.                 MultiLayerPerceptron mlp = createNeuralNetwork();
  38.                 mlp = trainNeuralNetwork(mlp, trainData);
  39.                 evaluate(mlp, testData);
  40.             }        
  41.         }
  42.         System.out.println("Average number of iterations per training is: "+(double) totalIterations/trainingCount);
  43.         int i = 1;
  44.         for (Training tr : trainings) {
  45.             System.out.println("Model "+i+"::mse::"+tr.mse);
  46.             i++;
  47.         }
  48.         System.out.println("");
  49.         saveBestNetwork();
  50.        
  51.     }
  52.  
  53.     @Override
  54.     public void handleLearningEvent(LearningEvent le) {
  55.         MomentumBackpropagation mbp = (MomentumBackpropagation) le.getSource();
  56.         System.out.println("interation:: "+mbp.getCurrentIteration()+":: total network error:: "+mbp.getTotalNetworkError());
  57.     }
  58.  
  59.     @Override
  60.     public DataSet loadDataSet() {
  61.         return DataSet.createFromFile("breast_cancer_data.csv", inputCount, outputCount, ",");
  62.     }
  63.  
  64.     @Override
  65.     public DataSet preprocessDataSet(DataSet ds) {
  66.         MaxNormalizer norm = new MaxNormalizer(ds);
  67.         norm.normalize(ds);
  68.         ds.shuffle();
  69.         return ds;
  70.     }
  71.  
  72.     @Override
  73.     public DataSet[] trainTestSplit(DataSet ds) {
  74.         return ds.split(trainSplit, testSplit);
  75.     }
  76.  
  77.     @Override
  78.     public MultiLayerPerceptron createNeuralNetwork() {
  79.         return new MultiLayerPerceptron(inputCount,hiddenNeuron,outputCount);
  80.     }
  81.  
  82.     @Override
  83.     public MultiLayerPerceptron trainNeuralNetwork(MultiLayerPerceptron mlp, DataSet ds) {
  84.         MomentumBackpropagation mbp = (MomentumBackpropagation) mlp.getLearningRule();
  85.         //mbp.addListener(this);
  86.         mbp.setLearningRate(learningRate);
  87.         mbp.setMomentum(momentumValue);
  88.         mbp.setMaxIterations(1000);
  89.         mlp.learn(ds);
  90.        
  91.         trainingCount++;
  92.         totalIterations+=mbp.getCurrentIteration();
  93.        
  94.         return mlp;
  95.     }
  96.  
  97.     @Override
  98.     public void evaluate(MultiLayerPerceptron mlp, DataSet ds) {
  99.         double mse = 0;
  100.         int total = 0;
  101.         for (DataSetRow d : ds) {
  102.             mlp.setInput(d.getInput());
  103.             mlp.calculate();
  104.            
  105.             double[] actual = d.getDesiredOutput();
  106.             double[] predicted = mlp.getOutput();
  107.            
  108.            
  109.             mse += (double) Math.pow(actual[0]-predicted[0],2);
  110.             total++;
  111.         }
  112.         mse = (double) mse / (2*testData.size());
  113.        
  114.         System.out.println("Mean Squared Error:: "+mse);
  115.        
  116.         Training tr = new Training(mlp, mse);
  117.         trainings.add(tr);
  118.        
  119.     }
  120.    
  121.  
  122.     @Override
  123.     public void saveBestNetwork() {
  124.        int minIndex = 0;
  125.        
  126.        for(int i=0;i<trainings.size();i++){
  127.            if(trainings.get(i).mse<trainings.get(minIndex).mse){
  128.                minIndex = i;
  129.            }
  130.        }
  131.        trainings.get(minIndex).mlp.save("bestNN.nnet");
  132.        System.out.println("Best NN is Model "+(minIndex+1)+" with mse of "+trainings.get(minIndex).mse);
  133.        System.out.println("Best NN is saved.");
  134.     }
  135. }
  136.  
  137. //WINES
  138.  
  139. public class Winess implements LearningEventListener,NeurophExam{
  140.    
  141.    
  142.     ArrayList<Training> trainings = new ArrayList<>();
  143.     int inputCount = 13;
  144.     int outputCount = 3;
  145.     int hiddenCount = 22;
  146.     double[] learningRates = {0.2,0.4,0.6};
  147.     DataSet testData;
  148.     DataSet trainData;
  149.     int trainingCount = 0;
  150.     int totalIterations = 0;
  151.     double learningRate;
  152.      
  153.     public static void main(String[] args) {
  154.        new Winess().run();
  155.     }
  156.  
  157.     @Override
  158.     public void handleLearningEvent(LearningEvent le) {
  159.         BackPropagation bp = (BackPropagation) le.getSource();
  160.         System.out.println("iteration::"+bp.getCurrentIteration()+" total network error::"+bp.getTotalNetworkError());
  161.     }
  162.  
  163.     @Override
  164.     public DataSet loadDataSet() {
  165.         DataSet ds = DataSet.createFromFile("wines.csv", inputCount, outputCount, ",");
  166.         return ds;
  167.     }
  168.  
  169.     @Override
  170.     public DataSet preprocessDataSet(DataSet ds) {
  171.         MaxNormalizer norm = new MaxNormalizer(ds);
  172.         norm.normalize(ds);
  173.         ds.shuffle();
  174.         return ds;
  175.     }
  176.  
  177.     @Override
  178.     public DataSet[] trainTestSplit(DataSet ds) {
  179.         return ds.split(0.7,0.3);
  180.     }
  181.  
  182.     @Override
  183.     public MultiLayerPerceptron createNeuralNetwork() {
  184.         return new MultiLayerPerceptron(inputCount,hiddenCount,outputCount);
  185.     }
  186.  
  187.     @Override
  188.     public MultiLayerPerceptron trainNeuralNetwork(MultiLayerPerceptron mlp, DataSet ds) {
  189.         BackPropagation bp = (BackPropagation) mlp.getLearningRule();
  190.         bp.addListener(this);  
  191.         bp.setLearningRate(learningRate);
  192.         bp.setMaxError(0.02);
  193.         bp.setMaxIterations(1000);
  194.         mlp.learn(ds);
  195.        
  196.         trainingCount++;
  197.         totalIterations+=bp.getCurrentIteration();
  198.         return mlp;      
  199.     }
  200.  
  201.     @Override
  202.     public void evaluate(MultiLayerPerceptron mlp, DataSet ds) {
  203.         String[] classLabels = {"c1","c2","c3"};
  204.         ConfusionMatrix cm = new ConfusionMatrix(classLabels);
  205.         double accuracy = 0;
  206.        
  207.         for (DataSetRow d : ds) {
  208.             mlp.setInput(d.getInput());
  209.             mlp.calculate();
  210.             double[] actual = d.getDesiredOutput();
  211.             double[] predicted = mlp.getOutput();
  212.            
  213.             int maxActual = maxIndex(actual);
  214.             int maxPredicted = maxIndex(predicted);
  215.            
  216.             cm.incrementElement(maxActual, maxPredicted);
  217.         }
  218.         for(int i=0;i<outputCount;i++){
  219.             accuracy+=(double)(cm.getTruePositive(i)+cm.getTrueNegative(i))/cm.getTotal();
  220.         }
  221.         accuracy = (double) accuracy/outputCount;
  222.         System.out.println(cm.toString());
  223.         System.out.println("accuracy::"+accuracy);
  224.         Training tr = new Training(mlp, accuracy);
  225.         trainings.add(tr);
  226.      
  227.     }
  228.    
  229.     public int maxIndex(double[] array){
  230.         int maxIndex = 0;
  231.         double maxValue = 0;
  232.        
  233.         for(int i=0;i<array.length;i++){
  234.             if(array[i]>maxValue){
  235.                 maxValue = array[i];
  236.                 maxIndex = i;
  237.             }
  238.         }
  239.        
  240.         return maxIndex;
  241.     }
  242.  
  243.     @Override
  244.     public void saveBestNetwork() {
  245.         int maxIdx = 0;
  246.         for (int i = 1; i < trainings.size(); i++) {
  247.             if (trainings.get(i).getAccuracy() > trainings.get(maxIdx).getAccuracy()) {
  248.                 maxIdx = i;
  249.             }
  250.         }
  251.         MultiLayerPerceptron bestNN = trainings.get(maxIdx).getMlp();
  252.         if(bestNN!=null){
  253.             bestNN.save("bestNN.nnet");
  254.             System.out.println("Best NN is Model "+(maxIdx+1)+":: "+trainings.get(maxIdx).getAccuracy());
  255.             System.out.println("Saving best NN!");
  256.         }else
  257.         {
  258.             System.out.println("Error!");
  259.         }
  260.        
  261.     }
  262.    
  263.     public void run(){
  264.         DataSet ds = loadDataSet();
  265.         ds = preprocessDataSet(ds);
  266.         DataSet[] trainTest = trainTestSplit(ds);
  267.        
  268.         trainData = trainTest[0];
  269.         testData = trainTest[1];
  270.        
  271.         for (double lr : learningRates) {
  272.             learningRate = lr;
  273.             MultiLayerPerceptron mlp = createNeuralNetwork();
  274.             mlp = trainNeuralNetwork(mlp, trainData);
  275.            
  276.             System.out.println("Model learning rate::"+lr);
  277.             evaluate(mlp, testData);
  278.            
  279.         }
  280.         System.out.println("Average number of iterations per training::"+totalIterations/trainingCount);
  281.  
  282.         int i = 1;
  283.         for(Training tr:trainings) {
  284.             System.out.println("Model " +i+ " :: " + tr.accuracy);
  285.             i++;
  286.         }
  287.         saveBestNetwork();
  288.     }
  289. }
  290.  
  291. //GLASS
  292.  
  293. public class Glasss implements NeurophExam,LearningEventListener{
  294.    
  295.     ArrayList<Training> trainings = new ArrayList<>();
  296.     int inputCount = 9;
  297.     int outputCount = 7;
  298.     DataSet trainData;
  299.     DataSet testData;
  300.     int[] hiddenNeurons = {10,20,30};
  301.     double[] learningRates = {0.2,0.4,0.6};
  302.     int hiddenNeuron;
  303.     double learningRate;
  304.     int trainCount = 0;
  305.     int totalIterations = 0;
  306.     double tempAccuracy = 0;
  307.    
  308.     public static void main(String[] args) {
  309.         Glasss glass = new Glasss();
  310.         glass.run();
  311.     }
  312.    
  313.     @Override
  314.     public void handleLearningEvent(LearningEvent event) {
  315.         MomentumBackpropagation mbp = (MomentumBackpropagation) event.getSource();
  316.         System.out.println("interations: "+mbp.getCurrentIteration()+":total network error :"+mbp.getTotalNetworkError());
  317.     }
  318.  
  319.     @Override
  320.     public DataSet loadDataSet() {
  321.         DataSet ds = DataSet.createFromFile("glass.csv", inputCount, outputCount, ",");
  322.        
  323.         return ds;
  324.     }
  325.  
  326.     @Override
  327.     public DataSet preprocessDataSet(DataSet ds) {
  328.         MaxNormalizer norm = new MaxNormalizer(ds);
  329.         norm.normalize(ds);
  330.         ds.shuffle();
  331.         return ds;
  332.     }
  333.  
  334.     @Override
  335.     public DataSet[] trainTestSplit(DataSet ds) {
  336.         return ds.split(0.65,0.35);
  337.     }
  338.  
  339.     @Override
  340.     public MultiLayerPerceptron createNeuralNetwork() {
  341.         return new MultiLayerPerceptron(inputCount,hiddenNeuron,outputCount);
  342.     }
  343.  
  344.     @Override
  345.     public MultiLayerPerceptron trainNeuralNetwork(MultiLayerPerceptron mlp, DataSet ds) {
  346.         MomentumBackpropagation mbp = (MomentumBackpropagation) mlp.getLearningRule();
  347.        
  348.         mbp.addListener(this);
  349.         mbp.setMomentum(0.6);
  350.         mbp.setLearningRate(learningRate);
  351.         mbp.setMaxIterations(1000);
  352.         mlp.learn(ds);
  353.        
  354.         totalIterations += mbp.getCurrentIteration();
  355.         trainCount++;
  356.        
  357.         return mlp;
  358.     }
  359.  
  360.     @Override
  361.     public void evaluate(MultiLayerPerceptron mlp, DataSet ds) {
  362.         String[] classLabels = {"c1","c2","c3","c4","c5","c6","c7"};
  363.         ConfusionMatrix cm = new ConfusionMatrix(classLabels);
  364.         double accuracy = 0;
  365.        
  366.         for (DataSetRow row : ds) {
  367.             mlp.setInput(row.getInput());
  368.             mlp.calculate();
  369.            
  370.             double[] actual = row.getDesiredOutput();
  371.             double[] predicted = mlp.getOutput();
  372.            
  373.             int classActual = returnMaxIndex(actual);
  374.             int classPredicted = returnMaxIndex(predicted);
  375.            
  376.             cm.incrementElement(classActual, classPredicted);
  377.         }
  378.        
  379.         for(int i=0;i<outputCount;i++){
  380.             accuracy += (double) (cm.getTruePositive(i) + cm.getTrueNegative(i))/cm.getTotal();
  381.         }
  382.         accuracy = (double) accuracy/outputCount;
  383.         System.out.println("NN::hidden neurons::"+hiddenNeuron+"::learning rate::"+learningRate);
  384.         System.out.println(cm.toString());
  385.         System.out.println("accuracy::"+accuracy);
  386.         Training tr = new Training(mlp, accuracy);
  387.         trainings.add(tr);
  388.     }
  389.    
  390.     public int returnMaxIndex(double[] array){
  391.         double maxValue = 0;
  392.         int maxIndex = 0;
  393.        
  394.         for(int i=0;i<array.length;i++){
  395.             if(array[i]>maxValue){
  396.                 maxValue = array[i];
  397.                 maxIndex = i;
  398.             }
  399.         }
  400.         return maxIndex;
  401.     }
  402.  
  403.     @Override
  404.     public void saveBestNetwork() {
  405.         double maxAcc = 0;
  406.         MultiLayerPerceptron bestNN = null;
  407.         for (Training tr : trainings) {
  408.             if(tr.accuracy>maxAcc){
  409.                 maxAcc = tr.accuracy;
  410.                 bestNN = tr.mlp;
  411.             }
  412.         }
  413.         if(bestNN!=null){
  414.             bestNN.save("bestNN.nnet");
  415.             System.out.println("Najbolja NN ima accuracy: "+maxAcc);
  416.             System.out.println("Sacuvana je najbolja NN!");
  417.         }else
  418.         {
  419.             System.out.println("Greska!");
  420.         }
  421.        
  422.     }
  423.    
  424.     public void run(){
  425.         DataSet ds = loadDataSet();
  426.         ds = preprocessDataSet(ds);
  427.         DataSet[] trainTest = trainTestSplit(ds);
  428.         trainData = trainTest[0];
  429.         testData = trainTest[1];
  430.        
  431.         for (double i: learningRates) {
  432.             for (int j : hiddenNeurons) {
  433.                 learningRate = i;
  434.                 hiddenNeuron = j;
  435.                 MultiLayerPerceptron mlp = createNeuralNetwork();                
  436.                 mlp = trainNeuralNetwork(mlp, trainData);
  437.                
  438.                 evaluate(mlp, testData);
  439.                                
  440.             }
  441.         }        
  442.                
  443.         System.out.println("Srednja vrednost iteracija potrebnih za trening svih mreza: "+totalIterations/trainCount);
  444.        
  445.         System.out.println("Modeli:");
  446.         int i = 1;
  447.         for(Training tr:trainings) {
  448.             System.out.println("Model " +i+ " :: " + tr.accuracy);
  449.             i++;
  450.         }
  451.        
  452.         saveBestNetwork();
  453.  
  454.     }
  455. }
  456.  
  457. //DIABETESS
  458.  
  459. public class MojZadatak implements NeurophExam, LearningEventListener {
  460.    
  461.     Map<Double, MultiLayerPerceptron> mapa = new HashMap<>();
  462.     int inputCount = 8;
  463.     int outputCount = 1;
  464.     DataSet trainSet;
  465.     DataSet testSet;
  466.     double[] learningRates = {0.2, 0.3, 0.4};
  467.     //ArrayList<Training> trainings = new ArrayList<>();
  468.  
  469.     /**
  470.      * U ovoj metodi pozivati sve metode koje cete implementirati iz NeurophExam
  471.      * interfejsa
  472.      */
  473.     private void run() {
  474.         DataSet ds = loadDataSet();
  475.         ds = preprocessDataSet(ds);
  476.         DataSet[] trainAndTest = trainTestSplit(ds);
  477.         trainSet = trainAndTest[0];
  478.         testSet = trainAndTest[1];
  479.         MultiLayerPerceptron neuralNet = createNeuralNetwork();
  480.         trainNeuralNetwork(neuralNet, ds);
  481.         saveBestNetwork();
  482.         int i = 1;
  483.         for (Map.Entry<Double, MultiLayerPerceptron> entry : mapa.entrySet()) {
  484.             System.out.println("Model " +i+ " :: " + entry.getKey());
  485.             i++;
  486.         }
  487.     }
  488.  
  489.     @Override
  490.     public DataSet loadDataSet() {
  491.         DataSet dataSet = DataSet.createFromFile("diabetes_data.csv", inputCount, outputCount, ",");
  492.         return dataSet;
  493.     }
  494.  
  495.     @Override
  496.     public DataSet preprocessDataSet(DataSet ds) {
  497.         Normalizer norm = new MaxNormalizer(ds);
  498.         norm.normalize(ds);
  499.         ds.shuffle();
  500.         return ds;
  501.     }
  502.  
  503.     @Override
  504.     public DataSet[] trainTestSplit(DataSet ds) {
  505.         return ds.split(0.6, 0.4);
  506.     }
  507.  
  508.     @Override
  509.     public MultiLayerPerceptron createNeuralNetwork() {
  510.         return new MultiLayerPerceptron(inputCount, 20, 16, outputCount);
  511.     }
  512.  
  513.     @Override
  514.     public MultiLayerPerceptron trainNeuralNetwork(MultiLayerPerceptron mlp, DataSet ds) {
  515.  
  516.         int numOfIterations = 0;
  517.         int numOfTrainings = 0;
  518.  
  519.         for (double lr : learningRates) {
  520.             MomentumBackpropagation learningRule = (MomentumBackpropagation) mlp.getLearningRule();
  521.             learningRule.addListener(this);
  522.  
  523.             learningRule.setLearningRate(lr);
  524.             learningRule.setMaxError(0.07);
  525.             learningRule.setMomentum(0.5);
  526.             learningRule.setMaxIterations(1000);
  527.  
  528.             mlp.learn(trainSet);
  529.  
  530.             numOfTrainings++;
  531.             numOfIterations += learningRule.getCurrentIteration();
  532.  
  533.             evaluate(mlp, testSet);
  534.         }
  535.  
  536.         System.out.println("Srednja vrednost broja iteracija je: " + (double) numOfIterations / numOfTrainings);
  537.  
  538.         return mlp;
  539.  
  540.     }
  541.  
  542.     @Override
  543.     public void evaluate(MultiLayerPerceptron mlp, DataSet ds) {
  544.  
  545.         // ova matrica ima 2 classLabela jer ima 1 output
  546.         // matrica ne sme da bude 1x1
  547.         String[] classLabels = new String[]{"c1", "c2"};
  548.         ConfusionMatrix cm = new ConfusionMatrix(classLabels);
  549.         double accuracy = 0;
  550.  
  551.         for (DataSetRow dataSetRow : ds) {
  552.             mlp.setInput(dataSetRow.getInput());
  553.             mlp.calculate();
  554.  
  555.             int actual = (int) Math.round(dataSetRow.getDesiredOutput()[0]);
  556.             int predicted = (int) Math.round(mlp.getOutput()[0]);
  557.  
  558.             System.out.println("Actual: " + dataSetRow.getDesiredOutput()[0]
  559.                     + "\t Predicted: " + mlp.getOutput()[0]);
  560.  
  561.             cm.incrementElement(actual, predicted);
  562.         }
  563.  
  564.         accuracy = (double) (cm.getTruePositive(0) + cm.getTrueNegative(0)) / cm.getTotal();
  565.  
  566.         System.out.println(cm.toString());
  567.  
  568.         System.out.println("Moj accuracy: " + accuracy);
  569.        
  570.         mapa.put(accuracy,mlp);
  571.  
  572.         //Training t = new Training(mlp, accuracy);
  573.         //trainings.add(t);
  574.     }
  575.  
  576.     @Override
  577.     public void saveBestNetwork() {
  578.        
  579.         Double maxKey = Collections.max(mapa.keySet());
  580.        
  581.         MultiLayerPerceptron bestModel = mapa.get(maxKey);
  582.  
  583.         bestModel.save("nn.nnet");
  584.         System.out.println("BestNN::accuracy::"+maxKey);
  585.         System.out.println("Mreza sa najvecom tacnoscu je serijalizovana!");
  586.     }
  587.  
  588.     public static void main(String[] args) {
  589.         new MojZadatak().run();
  590.     }
  591.  
  592.     @Override
  593.     public void handleLearningEvent(LearningEvent le) {
  594.         MomentumBackpropagation bp = (MomentumBackpropagation) le.getSource();
  595.         System.out.println("Iteration: " + bp.getCurrentIteration()
  596.                 + " Total network error: " + bp.getTotalNetworkError());
  597.     }
  598.  
  599.     private int getMaxIndex(double[] output) {
  600.         int max = 0;
  601.         for (int i = 1; i < output.length; i++) {
  602.             if (output[max] < output[i]) {
  603.                 max = i;
  604.             }
  605.         }
  606.         return max;
  607.     }
  608. }
Advertisement
Add Comment
Please, Sign In to add comment