Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package test.example.two;
- import java.math.BigDecimal;
- import java.util.Random;
/**
 * Kohonen self-organizing map (SOM) over random binary input, trained in two parallel
 * copies: one using {@link BigDecimal} arithmetic ({@code w}, {@code d}, {@code pattern},
 * {@code tests}) and one using {@code double} arithmetic (the {@code *_d} fields). Both
 * copies start from the same random weights and see the same inputs, so their results can
 * be compared side by side in the printed report.
 *
 * <p>The map is a single row of {@code MAX_CLUSTERS} nodes, each with a weight vector of
 * length {@code VEC_LEN}. Every epoch presents each training pattern once, moves the node
 * with the smallest squared Euclidean distance (and, while the learning rate is above
 * {@code RADIUS_REDUCTION_POINT}, its immediate left/right neighbours) towards the pattern,
 * then multiplies the learning rate by {@code DECAY_RATE}. Training stops once the rate is
 * at or below {@code MIN_ALPHA}.
 *
 * <p>All state is held in static fields, so the class is not safe for concurrent use.
 */
public class KohonenSOM2_BD {
    private static final int MAX_CLUSTERS = 5;
    private static final int VEC_LEN = 7;
    private static final int INPUT_PATTERNS = 7;
    private static final int INPUT_TESTS = 6;
    /** Learning rate at the start of every training run. */
    private static final double INITIAL_ALPHA = 1;
    // 0.99^458 > MIN_ALPHA >= 0.99^459, so training runs 459 epochs.
    private static final double DECAY_RATE = 0.99;
    private static final double MIN_ALPHA = 0.01;
    // The learning rate first drops below this after epoch 376 of 459, so neighbours stop
    // being updated for roughly the last 20% of epochs.
    private static final double RADIUS_REDUCTION_POINT = 0.023;
    // BigDecimal weights are truncated to this many fractional digits after each update so
    // their precision does not keep growing with every multiplication.
    private static final int WEIGHT_SCALE = 100;
    private static final java.math.RoundingMode WEIGHT_ROUNDING = java.math.RoundingMode.DOWN;

    /** Current learning rate; reset to INITIAL_ALPHA and decayed by training(). */
    private static double alpha = INITIAL_ALPHA;
    // Squared distances from each node to the current input vector (filled by computeInput).
    private static final BigDecimal[] d = new BigDecimal[MAX_CLUSTERS];
    private static final double[] d_d = new double[MAX_CLUSTERS];
    // Weight matrices, one row per node, initialised with random values in [0.0, 1.0).
    private static BigDecimal[][] w;
    private static double[][] w_d;
    // Random binary (0/1) training and test vectors, one row per vector.
    private static BigDecimal[][] pattern;
    private static int[][] pattern_d;
    private static BigDecimal[][] tests;
    private static int[][] tests_d;

    /**
     * Allocates all weight and input arrays and fills them with fresh random data. The
     * BigDecimal copies are derived from the same random draws as the double/int copies.
     */
    private static void gen() {
        w = new BigDecimal[MAX_CLUSTERS][VEC_LEN];
        w_d = new double[MAX_CLUSTERS][VEC_LEN];
        pattern = new BigDecimal[INPUT_PATTERNS][VEC_LEN];
        pattern_d = new int[INPUT_PATTERNS][VEC_LEN];
        tests = new BigDecimal[INPUT_TESTS][VEC_LEN];
        tests_d = new int[INPUT_TESTS][VEC_LEN];
        Random randy = new Random();
        for (int i = 0; i < w.length; i++) {
            for (int j = 0; j < w[i].length; j++) {
                w_d[i][j] = randy.nextDouble();
                // valueOf(double) converts via Double.toString, like the original constructor call.
                w[i][j] = BigDecimal.valueOf(w_d[i][j]);
            }
        }
        fillBinary(pattern, pattern_d, randy);
        fillBinary(tests, tests_d, randy);
    }

    /** Fills {@code ints} with random 0/1 values and mirrors each value into {@code decimals}. */
    private static void fillBinary(BigDecimal[][] decimals, int[][] ints, Random randy) {
        for (int i = 0; i < ints.length; i++) {
            for (int j = 0; j < ints[i].length; j++) {
                ints[i][j] = randy.nextInt(2);
                decimals[i][j] = BigDecimal.valueOf(ints[i][j]);
            }
        }
    }

    /**
     * Trains both maps on the training patterns until the learning rate decays to
     * MIN_ALPHA, then prints the epoch count and the epoch after which neighbour updates
     * stopped.
     */
    private static void training() {
        // Reset the learning rate so that a second run (e.g. calling main twice) actually
        // trains instead of skipping the loop with alpha already at MIN_ALPHA.
        alpha = INITIAL_ALPHA;
        int iterations = 0;
        int reductionPoint = 0; // 0 means "not reached yet"; iterations is always >= 1 here
        while (alpha > MIN_ALPHA) {
            iterations++;
            for (int vecNum = 0; vecNum < INPUT_PATTERNS; vecNum++) {
                computeInput(pattern, pattern_d, vecNum);
                updateWeights(vecNum, minimum(d), minimum(d_d));
            }
            alpha *= DECAY_RATE;
            if (reductionPoint == 0 && alpha < RADIUS_REDUCTION_POINT) {
                reductionPoint = iterations;
            }
        }
        System.out.println("Iterations: " + iterations);
        System.out.println("Neighborhood radius reduced after " + reductionPoint + " iterations.");
    }

    /**
     * Stores in {@code d}/{@code d_d} the squared distance from every node to input vector
     * {@code vectorNumber} of {@code vectorArray}/{@code vectorArray_d}.
     */
    private static void computeInput(BigDecimal[][] vectorArray, int[][] vectorArray_d, int vectorNumber) {
        for (int i = 0; i < MAX_CLUSTERS; i++) {
            d_d[i] = squaredDistance(w_d[i], vectorArray_d[vectorNumber]);
            // Must use the selected input vector, not row i (which is the node index).
            d[i] = squaredDistance(w[i], vectorArray[vectorNumber]);
        }
    }

    /**
     * Returns the squared Euclidean distance between a weight row and an input vector.
     *
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    static BigDecimal squaredDistance(BigDecimal[] weights, BigDecimal[] vector) {
        if (weights.length != vector.length) {
            throw new IllegalArgumentException(
                    "length mismatch: " + weights.length + " != " + vector.length);
        }
        BigDecimal sum = BigDecimal.ZERO;
        for (int j = 0; j < weights.length; j++) {
            sum = sum.add(weights[j].subtract(vector[j]).pow(2));
        }
        return sum;
    }

    /**
     * Returns the squared Euclidean distance between a weight row and an input vector.
     *
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    static double squaredDistance(double[] weights, int[] vector) {
        if (weights.length != vector.length) {
            throw new IllegalArgumentException(
                    "length mismatch: " + weights.length + " != " + vector.length);
        }
        double sum = 0;
        for (int j = 0; j < weights.length; j++) {
            double diff = weights[j] - vector[j];
            sum += diff * diff;
        }
        return sum;
    }

    /**
     * Moves the winning nodes ({@code dMin} for the BigDecimal map, {@code dMin_d} for the
     * double map) towards training pattern {@code vectorNumber}. While alpha is above
     * RADIUS_REDUCTION_POINT, the existing left/right neighbours of each winner move too.
     */
    private static void updateWeights(int vectorNumber, int dMin, int dMin_d) {
        BigDecimal rate = BigDecimal.valueOf(alpha); // same value as new BigDecimal(Double.toString(alpha))
        BigDecimal[] target = pattern[vectorNumber];
        int[] target_d = pattern_d[vectorNumber];
        boolean updateNeighbours = alpha > RADIUS_REDUCTION_POINT;

        moveTowards(w[dMin], target, rate);
        if (updateNeighbours) {
            if (dMin > 0) {
                moveTowards(w[dMin - 1], target, rate);
            }
            if (dMin < MAX_CLUSTERS - 1) {
                moveTowards(w[dMin + 1], target, rate);
            }
        }

        moveTowards(w_d[dMin_d], target_d, alpha);
        if (updateNeighbours) {
            if (dMin_d > 0) {
                moveTowards(w_d[dMin_d - 1], target_d, alpha);
            }
            if (dMin_d < MAX_CLUSTERS - 1) {
                moveTowards(w_d[dMin_d + 1], target_d, alpha);
            }
        }
    }

    /** In place: weights += rate * (target - weights), truncated to WEIGHT_SCALE digits. */
    private static void moveTowards(BigDecimal[] weights, BigDecimal[] target, BigDecimal rate) {
        for (int i = 0; i < weights.length; i++) {
            weights[i] = weights[i].add(rate.multiply(target[i].subtract(weights[i])))
                    .setScale(WEIGHT_SCALE, WEIGHT_ROUNDING);
        }
    }

    /** In place: weights += rate * (target - weights). */
    private static void moveTowards(double[] weights, int[] target, double rate) {
        for (int i = 0; i < weights.length; i++) {
            weights[i] = weights[i] + (rate * (target[i] - weights[i]));
        }
    }

    /** Returns the index of the smallest value; ties go to the lowest index. */
    private static int minimum(BigDecimal[] nodeArray) {
        int winner = 0;
        for (int i = 1; i < nodeArray.length; i++) {
            if (nodeArray[i].compareTo(nodeArray[winner]) < 0) {
                winner = i;
            }
        }
        return winner;
    }

    /** Returns the index of the smallest value; ties go to the lowest index. */
    private static int minimum(double[] nodeArray) {
        int winner = 0;
        for (int i = 1; i < nodeArray.length; i++) {
            if (nodeArray[i] < nodeArray[winner]) {
                winner = i;
            }
        }
        return winner;
    }

    /**
     * Prints the category of each training vector, the BigDecimal weight matrix, and the
     * category of each test vector. Categories are printed as "bigDecimalNode (doubleNode)".
     */
    private static void printResults() {
        System.out.println("Clusters for training input:");
        printCategories(pattern, pattern_d);
        System.out.println("------------------------------------------------------------------------");
        for (int i = 0; i < MAX_CLUSTERS; i++) {
            System.out.println("Weights for Node " + i + " connections:");
            System.out.print(" ");
            for (int j = 0; j < VEC_LEN; j++) {
                System.out.print(String.format("%.3f", w[i][j]) + ", ");
            }
            System.out.print("\n");
        }
        System.out.println("------------------------------------------------------------------------");
        System.out.println("Categorized test input:");
        printCategories(tests, tests_d);
    }

    /** Prints one "Vector (...) fits into category X (Y)" line per row of {@code vectors}. */
    private static void printCategories(BigDecimal[][] vectors, int[][] vectors_d) {
        for (int vecNum = 0; vecNum < vectors.length; vecNum++) {
            computeInput(vectors, vectors_d, vecNum);
            int dMin = minimum(d);
            int dMin_d = minimum(d_d);
            System.out.print("Vector (");
            for (int i = 0; i < VEC_LEN; i++) {
                System.out.print(vectors[vecNum][i] + ", ");
            }
            System.out.print(") fits into category " + dMin + " (" + dMin_d + ") \n");
        }
    }

    public static void main(String[] args) {
        gen();
        training();
        printResults();
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement