AleksaLjujic

IS NN Kod

Jul 26th, 2025
904
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 6.12 KB | None | 0 0
  1.  * GLAVNI NN KOD
  2.  */
  3. public class NNVezba implements LearningEventListener{
  4.     int inputCount = 13;
  5.     int outputCount = 3;
  6.     double[] lr = {0.2, 0.4, 0.6};
  7.     ArrayList<Training> trainings = new ArrayList<>();
  8.     /**
  9.      * @param args the command line arguments
  10.      */
  11.     public static void main(String[] args) {
  12.         // TODO code application logic here
  13.        
  14.         (new NNVezba()).run();
  15.     }
  16.  
  17.     @Override
  18.     public void handleLearningEvent(LearningEvent event) {
  19.         BackPropagation bp = (BackPropagation) event.getSource();
  20.         System.out.println("iteration: " +bp.getCurrentIteration()+" "
  21.                            + "Total network error: " +bp.getTotalNetworkError());
  22.     }
  23.  
  24.     private void run() {
  25.         String filePath = "wines.csv";
  26.         DataSet dataSet = DataSet.createFromFile(filePath, inputCount, outputCount, ",");
  27.        
  28.         Normalizer norm = new MaxNormalizer(dataSet);
  29.         norm.normalize(dataSet);
  30.         dataSet.shuffle();
  31.        
  32.         DataSet[] trainTest = dataSet.split(0.7, 0.3);
  33.         DataSet trainSet = trainTest[0];
  34.         DataSet testSet = trainTest[1];
  35.        
  36.         int numOfIterations = 0;
  37.         int numOfTrainings = 0;
  38.        
  39.         for (double l : lr) {
  40.             MultiLayerPerceptron nn = new MultiLayerPerceptron(inputCount, 22, outputCount);
  41.            
  42.             BackPropagation learningRule = nn.getLearningRule();
  43.            
  44.             learningRule.addListener(this);
  45.            
  46.             learningRule.setMaxError(0.02);
  47.             learningRule.setLearningRate(l);
  48.             learningRule.setMaxIterations(1000);
  49.            
  50.             nn.learn(trainSet);
  51.                    
  52.             numOfTrainings++;
  53.             numOfIterations += learningRule.getCurrentIteration();
  54.            
  55.             double accuracy = evaluateAccuracy(nn, testSet);
  56.             Training t = new Training(nn, accuracy);
  57.             trainings.add(t);
  58.         }
  59.         System.out.println("Srednja vrednosti broja iteracija :" +numOfTrainings/numOfIterations);
  60.         SaveNetWithMaxAccuracy();
  61.        
  62.     }
  63.  
  64.     private double evaluateAccuracy(MultiLayerPerceptron nn, DataSet testSet) {
  65.         ConfMatrix cmatrix = new ConfMatrix(outputCount);
  66.         double accuracy = 0;
  67.        
  68.         for (DataSetRow dataSetRow : testSet) {
  69.             nn.setInput(dataSetRow.getInput());
  70.             nn.calculate();
  71.            
  72.             int actual = getMaxIndex(dataSetRow.getDesiredOutput());
  73.             int predicted = getMaxIndex(nn.getOutput());
  74.            
  75.             cmatrix.incrementElement(actual, predicted);
  76.                
  77.         }
  78.        
  79.         for (int i = 0; i < outputCount; i++) {
  80.             accuracy += (double) (cmatrix.getTruePositive(i) + cmatrix.getTrueNegative(i)) / cmatrix.total;
  81.         }
  82.        
  83.         cmatrix.print();
  84.        
  85.         System.out.println("accuracy : "+(double) accuracy/outputCount);
  86.        
  87.         return (double) accuracy / outputCount;
  88.     }
  89.  
  90.     private void SaveNetWithMaxAccuracy() {
  91.         Training maxTr = trainings.get(0);
  92.         for (Training t : trainings) {
  93.             if(t.getAccuracy() > maxTr.getAccuracy()){
  94.                  maxTr = t;
  95.             }  
  96.         }
  97.         maxTr.getNn().save("nn.nnet");
  98.     }
  99.  
  100.     private int getMaxIndex(double[] output) {
  101.         int maxIndex = 0;
  102.        
  103.         for (int i = 0; i < output.length; i++) {
  104.             if(output[maxIndex] < output[i]){
  105.                 maxIndex = i;
  106.             }
  107.         }
  108.        
  109.         return maxIndex;
  110.    }
  111.    
  112. }
  113.  
  114. ________________________________________________________________________________________________________________
  115.  
  116. /**
  117.  *
  118.  * Training klasa
  119.  */
  120.    
  121. package nn.vezba;
  122.  
  123. import org.neuroph.core.NeuralNetwork;
  124.  
  125. public class Training {
  126.     private NeuralNetwork nn;
  127.     private double accuracy;
  128.  
  129.     public Training(NeuralNetwork nn, double accuracy) {
  130.         this.nn = nn;
  131.         this.accuracy = accuracy;
  132.     }
  133.  
  134.     public double getAccuracy() {
  135.         return accuracy;
  136.     }
  137.  
  138.     public void setAccuracy(double accuracy) {
  139.         this.accuracy = accuracy;
  140.     }
  141.  
  142.     public NeuralNetwork getNn() {
  143.         return nn;
  144.     }
  145.  
  146.     public void setNn(NeuralNetwork nn) {
  147.         this.nn = nn;
  148.     }
  149.    
  150. }
  151.  
  152. ________________________________________________________________________________________________________________
  153.  
  154. /**
  155.  *
  156.  * ConfMatrix klasa  (vecinom se kopira iz ConfusionMatrix klase)
  157.  */
  158.  
  159. public class ConfMatrix {
  160.     int[][] matrix;    
  161.     int classCount;
  162.     int total = 0;
  163.  
  164.     public ConfMatrix(int classCount) {
  165.         this.matrix = new int[classCount][classCount];
  166.         this.classCount = classCount;
  167.     }
  168.    
  169.     public void incrementElement(int actual, int predicted) {
  170.         matrix[actual][predicted]++;
  171.         total++;
  172.     }
  173.    
  174.      public int getTruePositive(int cl) {
  175.         return (int)matrix[cl][cl];
  176.     }
  177.    
  178.     public int getTrueNegative(int cl) {
  179.         int trueNegative = 0;
  180.        
  181.         for(int i = 0; i < classCount; i++) {
  182.             if (i == cl) continue;
  183.             for(int j = 0; j < classCount; j++) {
  184.                 if (j == cl) continue;
  185.                 trueNegative += matrix[i][j];
  186.             }
  187.         }
  188.        
  189.         return trueNegative;
  190.     }    
  191.  
  192.     public int getFalsePositive(int cl) {
  193.         int falsePositive = 0;
  194.        
  195.         for(int i=0; i<classCount; i++) {
  196.             if (i == cl) continue;
  197.             falsePositive += matrix[i][cl];
  198.         }
  199.        
  200.         return falsePositive;
  201.     }
  202.  
  203.     public int getFalseNegative(int cl) {
  204.         int falseNegative = 0;
  205.        
  206.         for(int i=0; i<classCount; i++) {
  207.             if (i == cl) continue;
  208.             falseNegative += matrix[cl][i];
  209.         }
  210.        
  211.         return falseNegative;
  212.     }
  213.     public void print(){
  214.         for (int i = 0; i < matrix.length; i++) {  
  215.             for (int j = 0; j < matrix.length; j++) {
  216.                 System.out.print(matrix[i][j] + " ");
  217.             }
  218.             System.out.println();
  219.         }
  220.     }
  221. }
Advertisement
Add Comment
Please, Sign In to add comment