Advertisement
Guest User

Kohonen SOM

a guest
Feb 6th, 2016
188
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 9.78 KB | None | 0 0
  1. package test.example.two;
  2.  
  3. import java.math.BigDecimal;
  4. import java.util.Random;
  5.  
  6. public class KohonenSOM2_BD
  7. {
  8.     private static final int MAX_CLUSTERS = 5;
  9.     private static final int VEC_LEN = 7;
  10.     private static final int INPUT_PATTERNS = 7;
  11.     private static final int INPUT_TESTS = 6;
  12.     private static final double DECAY_RATE = 0.99;                  //About 100 iterations.
  13.     private static final double MIN_ALPHA = 0.01;
  14.     private static final double RADIUS_REDUCTION_POINT = 0.023;     //Last 20% of iterations.
  15.    
  16.     private static double alpha = 1;
  17.     private static BigDecimal d[] = new BigDecimal[MAX_CLUSTERS];
  18.     private static double d_d[] = new double[MAX_CLUSTERS];
  19.    
  20.     //Weight matrix with randomly chosen values between 0.0 and 1.0
  21.     private static BigDecimal[][] w;
  22.     private static double[][] w_d;
  23.    
  24.     private static BigDecimal pattern[][];
  25.     private static int pattern_d[][];
  26.    
  27.     private static BigDecimal tests[][];
  28.     private static int tests_d[][];
  29.    
  30.     private static void gen(){
  31.         w = new BigDecimal[MAX_CLUSTERS][VEC_LEN];
  32.         w_d = new double[MAX_CLUSTERS][VEC_LEN];
  33.         pattern = new BigDecimal[INPUT_PATTERNS][VEC_LEN];
  34.         pattern_d = new int[INPUT_PATTERNS][VEC_LEN];
  35.         tests = new BigDecimal[INPUT_TESTS][VEC_LEN];
  36.         tests_d = new int[INPUT_TESTS][VEC_LEN];
  37.         Random randy = new Random();
  38.         for(int i=0;i<w.length;i++){
  39.             for(int j=0;j<w[0].length;j++){
  40.                 w_d[i][j] = randy.nextDouble();
  41.                 w[i][j] = new BigDecimal(Double.toString(w_d[i][j]));
  42.             }
  43.         }
  44.         for(int i=0;i<pattern.length;i++){
  45.             for(int j=0;j<pattern[0].length;j++){
  46.                 pattern_d[i][j] = randy.nextInt(2);
  47.                 pattern[i][j] = new BigDecimal(Integer.toString(pattern_d[i][j]));
  48.             }
  49.         }
  50.         for(int i=0;i<tests.length;i++){
  51.             for(int j=0;j<tests[0].length;j++){
  52.                 tests_d[i][j] = randy.nextInt(2);
  53.                 tests[i][j] = new BigDecimal(Integer.toString(tests_d[i][j]));
  54.             }
  55.         }
  56.     }
  57.     private static void training()
  58.     {
  59.         int iterations = 0;
  60.         boolean reductionFlag = false;
  61.         int reductionPoint = 0;
  62.         int dMin = 0,dMin_d=0;
  63.  
  64.         while(alpha > MIN_ALPHA)
  65.         {
  66.             iterations += 1;
  67.  
  68.             for(int vecNum = 0; vecNum <= (INPUT_PATTERNS - 1); vecNum++)
  69.             {
  70.                 //Compute input for all nodes.
  71.                 computeInput(pattern,pattern_d, vecNum);
  72.  
  73.                 //See which is smaller?
  74.                 dMin = minimum(d);
  75.                 dMin_d = minimum(d_d);
  76.  
  77.                 //Update the weights on the winning unit.
  78.                 updateWeights(vecNum, dMin, dMin_d);
  79.  
  80.             } // VecNum
  81.  
  82.             //Reduce the learning rate.
  83.             alpha = DECAY_RATE * alpha;
  84.  
  85.             //Reduce radius at specified point.
  86.             if(alpha < RADIUS_REDUCTION_POINT){
  87.                 if(reductionFlag == false){
  88.                     reductionFlag = true;
  89.                     reductionPoint = iterations;
  90.                 }
  91.             }
  92.         }
  93.  
  94.         System.out.println("Iterations: " + iterations);
  95.        
  96.         System.out.println("Neighborhood radius reduced after " + reductionPoint + " iterations.");
  97.        
  98.         return;
  99.     }
  100.    
  101.     private static void computeInput(BigDecimal[][] vectorArray,int[][] vectorArray_d, int vectorNumber)
  102.     {
  103.         clearArray(d,d_d);
  104.  
  105.         for(int i = 0; i <= (MAX_CLUSTERS - 1); i++){
  106.             for(int j = 0; j <= (VEC_LEN - 1); j++){
  107.                 d_d[i] += Math.pow((w_d[i][j] - vectorArray_d[vectorNumber][j]), 2);
  108.                 d[i] = d[i].add(w[i][j].add(vectorArray[i][j].negate()).pow(2));
  109.                 //d[i] = d[i].pow(2);
  110.             } // j
  111.         } // i
  112.         return;
  113.     }
  114.    
  115.     private static void updateWeights(int vectorNumber, int dMin, int dMin_d)
  116.     {
  117.         for(int i = 0; i <= (VEC_LEN - 1); i++)
  118.         {
  119.             //Update the winner.
  120.             w_d[dMin_d][i] = w_d[dMin_d][i] + (alpha * (pattern_d[vectorNumber][i] - w_d[dMin_d][i]));
  121.             w[dMin][i] = w[dMin][i].add(new BigDecimal(Double.toString(alpha)).multiply(pattern[vectorNumber][i].add(w[dMin][i].negate())));
  122.             w[dMin][i] = w[dMin][i].setScale(100, BigDecimal.ROUND_DOWN);
  123.  
  124.             //Only include neighbors before radius reduction point is reached.
  125.             if(alpha > RADIUS_REDUCTION_POINT){
  126.                 if((dMin > 0) && (dMin < (MAX_CLUSTERS - 1))){
  127.                     //Update neighbor to the left...
  128.                     w[dMin-1][i] = w[dMin-1][i].add(new BigDecimal(Double.toString(alpha)).multiply(pattern[vectorNumber][i].add(w[dMin-1][i].negate())));
  129.                     w[dMin-1][i] = w[dMin-1][i].setScale(100, BigDecimal.ROUND_DOWN);
  130.                     //and update neighbor to the right.
  131.                     w[dMin+1][i] = w[dMin+1][i].add(new BigDecimal(Double.toString(alpha)).multiply(pattern[vectorNumber][i].add(w[dMin+1][i].negate())));
  132.                     w[dMin+1][i] = w[dMin+1][i].setScale(100, BigDecimal.ROUND_DOWN);
  133.                 } else {
  134.                     if(dMin == 0){
  135.                         //Update neighbor to the right.
  136.                         w[dMin+1][i] = w[dMin+1][i].add(new BigDecimal(Double.toString(alpha)).multiply(pattern[vectorNumber][i].add(w[dMin+1][i].negate())));
  137.                         w[dMin+1][i] = w[dMin+1][i].setScale(100, BigDecimal.ROUND_DOWN);
  138.                     } else {
  139.                         //Update neighbor to the left.
  140.                         w[dMin-1][i] = w[dMin-1][i].add(new BigDecimal(Double.toString(alpha)).multiply(pattern[vectorNumber][i].add(w[dMin-1][i].negate())));
  141.                         w[dMin-1][i] = w[dMin-1][i].setScale(100, BigDecimal.ROUND_DOWN);
  142.                     }
  143.                 }
  144.                 if((dMin_d > 0) && (dMin_d < (MAX_CLUSTERS - 1))){
  145.                     //Update neighbor to the left...
  146.                     w_d[dMin_d - 1][i] = w_d[dMin_d - 1][i] + (alpha * (pattern_d[vectorNumber][i] - w_d[dMin_d - 1][i]));
  147.                     //and update neighbor to the right.
  148.                     w_d[dMin_d + 1][i] = w_d[dMin_d + 1][i] + (alpha * (pattern_d[vectorNumber][i] - w_d[dMin_d + 1][i]));
  149.                 } else {
  150.                     if(dMin_d == 0){
  151.                         //Update neighbor to the right.
  152.                         w_d[dMin_d + 1][i] = w_d[dMin_d + 1][i] + (alpha * (pattern_d[vectorNumber][i] - w_d[dMin_d + 1][i]));
  153.                     } else {
  154.                         //Update neighbor to the left.
  155.                         w_d[dMin_d - 1][i] = w_d[dMin_d - 1][i] + (alpha * (pattern_d[vectorNumber][i] - w_d[dMin_d - 1][i]));
  156.                     }
  157.                 }
  158.             }
  159.            
  160.         } // i
  161.         return;
  162.     }
  163.    
  164.     private static void clearArray(BigDecimal[] d, double[] d_d)
  165.     {
  166.         for(int i = 0; i <= (MAX_CLUSTERS - 1); i++)
  167.         {
  168.             d[i] = new BigDecimal("0");
  169.             d_d[i] =0;
  170.         } // i
  171.         return;
  172.     }
  173.    
  174.     private static int minimum(BigDecimal[] nodeArray)
  175.     {
  176.         int winner = 0;
  177.         boolean foundNewWinner = false;
  178.         boolean done = false;
  179.  
  180.         while(!done)
  181.         {
  182.             foundNewWinner = false;
  183.             for(int i = 0; i <= (MAX_CLUSTERS - 1); i++)
  184.             {
  185.                 if(i != winner){             //Avoid self-comparison.
  186.                     if(nodeArray[i].compareTo(nodeArray[winner]) == -1){
  187.                         winner = i;
  188.                         foundNewWinner = true;
  189.                     }
  190.                 }
  191.             } // i
  192.  
  193.             if(foundNewWinner == false){
  194.                 done = true;
  195.             }
  196.         }
  197.         return winner;
  198.     }
  199.     private static int minimum(double[] nodeArray)
  200.     {
  201.         int winner = 0;
  202.         boolean foundNewWinner = false;
  203.         boolean done = false;
  204.  
  205.         while(!done)
  206.         {
  207.             foundNewWinner = false;
  208.             for(int i = 0; i <= (MAX_CLUSTERS - 1); i++)
  209.             {
  210.                 if(i != winner){             //Avoid self-comparison.
  211.                     if(nodeArray[i]<nodeArray[winner]){
  212.                         winner = i;
  213.                         foundNewWinner = true;
  214.                     }
  215.                 }
  216.             } // i
  217.  
  218.             if(foundNewWinner == false){
  219.                 done = true;
  220.             }
  221.         }
  222.         return winner;
  223.     }
  224.    
  225.     private static void printResults()
  226.     {
  227.         int dMin = 0, dMin_d=0;
  228.  
  229.         //Print clusters created.
  230.             System.out.println("Clusters for training input:");
  231.             for(int vecNum = 0; vecNum <= (INPUT_PATTERNS - 1); vecNum++)
  232.             {
  233.                 //Compute input.
  234.                 computeInput(pattern,pattern_d, vecNum);
  235.  
  236.                 //See which is smaller.
  237.                 dMin = minimum(d);
  238.                 dMin_d = minimum(d_d);
  239.  
  240.                 System.out.print("Vector (");
  241.                 for(int i = 0; i <= (VEC_LEN - 1); i++)
  242.                 {
  243.                     System.out.print(pattern[vecNum][i] + ", ");
  244.                 } // i
  245.                 System.out.print(") fits into category " + dMin + " ("+dMin_d+") \n");
  246.  
  247.             } // VecNum
  248.  
  249.         //Print weight matrix.
  250.             System.out.println("------------------------------------------------------------------------");
  251.             for(int i = 0; i <= (MAX_CLUSTERS - 1); i++)
  252.             {
  253.                 System.out.println("Weights for Node " + i + " connections:");
  254.                 System.out.print("     ");
  255.                 for(int j = 0; j <= (VEC_LEN - 1); j++)
  256.                 {
  257.                     String temp = String.format("%.3f", w[i][j]);
  258.                     System.out.print(temp + ", ");
  259.                 } // j
  260.                 System.out.print("\n");
  261.             } // i
  262.  
  263.         //Print post-training tests.
  264.             System.out.println("------------------------------------------------------------------------");
  265.             System.out.println("Categorized test input:");
  266.             for(int vecNum = 0; vecNum <= (INPUT_TESTS - 1); vecNum++)
  267.             {
  268.                 //Compute input for all nodes.
  269.                 computeInput(tests,tests_d, vecNum);
  270.  
  271.                 //See which is smaller.
  272.                 dMin = minimum(d);
  273.                 dMin_d = minimum(d_d);
  274.  
  275.                 System.out.print("Vector (");
  276.                 for(int i = 0; i <= (VEC_LEN - 1); i++)
  277.                 {
  278.                     System.out.print(tests[vecNum][i] + ", ");
  279.                 } // i
  280.                 System.out.print(") fits into category " + dMin + " ("+dMin_d+") \n");
  281.  
  282.             } // VecNum
  283.         return;
  284.     }
  285.    
  286.     public static void main(String[] args)
  287.     {
  288.         gen();
  289.         training();
  290.         printResults();
  291.         return;
  292.     }
  293.  
  294. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement