Advertisement
Guest User

EM Algorithm

a guest
Sep 28th, 2011
636
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 6.42 KB | None | 0 0
  1. /*
  2.  * Kocaeli University Computer Engineering Department 2011-Fall
  3.  *
  4.  * Class     : Data Mining
  5.  * Subject   : Inspection and Implementation of Model Based Clustering Method
  6.  *             Expectation-Maximization (EM)
  7.  *
  8.  * Student   : Mustafa KIYAR
  9.  * Number    : 105112014
  10.  *
  11.  */
  12. package em;
  13.  
  14. import java.io.BufferedWriter;
  15. import java.io.FileInputStream;
  16. import java.io.FileWriter;
  17. import java.io.IOException;
  18. import java.util.Properties;
  19. import org.apache.commons.lang3.ArrayUtils;
  20.  
  21. /**
  22.  *
  23.  * @author Mustafa KIYAR
  24.  */
  25. public class EM {
  26.  
  27.     public static double findProbablity(double x, double mu, double sigma) {
  28.         double dExp = Math.exp(-Math.pow(x - mu, 2) / (2 * Math.pow(sigma, 2)));
  29.         return dExp / (Math.sqrt(2 * Math.PI) * sigma);
  30.     }
  31.  
  32.     public static double findMean(double dArr[]) {
  33.         double mu = 0;
  34.         for (int i = 0; i < dArr.length; i++) {
  35.             mu += dArr[i];
  36.         }
  37.         return mu / dArr.length;
  38.     }
  39.  
  40.     public static double findStandartDeviation(double dArr[], double mu) {
  41.         double sigma = 0;
  42.         for (int i = 0; i < dArr.length; i++) {
  43.             sigma += Math.pow(dArr[i] - mu, 2);
  44.         }
  45.         return Math.sqrt(sigma / dArr.length);
  46.     }
  47.  
  48.     public static void myLog(String str) {
  49.         System.out.println(str);
  50.     }
  51.  
  52.     public static void analyzeArr(double dArr[]) {
  53.         double mu, sigma;
  54.  
  55.         mu = findMean(dArr);
  56.         sigma = findStandartDeviation(dArr, mu);
  57.  
  58.         myLog("mean               : " + mu);
  59.         myLog("standart deviation : " + sigma);
  60.  
  61.         for (int i = 0; i < dArr.length; i++) {
  62.             myLog("probablity of " + dArr[i] + " is "
  63.                     + findProbablity(dArr[i], mu, sigma));
  64.         }
  65.     }
  66.  
  67.     public static double findProbablity(double x, double dArr[]) {
  68.         double mu, sigma, probablity;
  69.  
  70.         mu = findMean(dArr);
  71.         sigma = findStandartDeviation(dArr, mu);
  72.         probablity = findProbablity(x, mu, sigma);
  73.  
  74.         myLog("mean                 : " + mu);
  75.         myLog("standart deviation   : " + sigma);
  76.         myLog("probablity of object : " + probablity);
  77.         return probablity;
  78.     }
  79.  
  80.     public static int findRelation(double x, double dArr[][]) {
  81.         double nominator = 0, denominator = 0, probablity = 0;
  82.         int iBestCluster = 0;
  83.         for (int i = 0; i < dArr.length; i++) {
  84.             myLog("");
  85.             myLog("Inspection of " + x + " in Cluster :" + i);
  86.             probablity = findProbablity(x, dArr[i]);
  87.             denominator += probablity;
  88.             if (nominator < probablity) {
  89.                 nominator = probablity;
  90.                 iBestCluster = i;
  91.             }
  92.         }
  93.         return iBestCluster;
  94.     }
  95.  
  96.     public static void switchCluster(int iCluster, int iPos, int iBestCluster,
  97.             double dArr[][]) {
  98.  
  99.         double dVal = dArr[iCluster][iPos];
  100.         dArr[iCluster] = ArrayUtils.remove(dArr[iCluster], iPos);
  101.         dArr[iBestCluster] = ArrayUtils.add(dArr[iBestCluster], dVal);
  102.  
  103.         myLog("Object           : " + dVal);
  104.         myLog("Original cluster : " + iCluster);
  105.         myLog("Moved to cluster : " + iBestCluster);
  106.  
  107.     }
  108.  
  109.     public static void runProgram(double dArr[][]) {
  110.         int iBestCluster;
  111.         boolean check;
  112.         do {
  113.             check = false;
  114.             for (int i = 0; i < dArr.length; i++) {
  115.                 myLog("##########################################################");
  116.                 myLog("Scaning Cluster : " + i);
  117.                 for (int j = 0; j < dArr[i].length; j++) {
  118.                     iBestCluster = findRelation(dArr[i][j], dArr);
  119.                     if (i != iBestCluster) {
  120.                         switchCluster(i, j, iBestCluster, dArr);
  121.                         j--;
  122.                         check = true;
  123.                     } else {
  124.                         myLog("Object belongs to orginal cluster.");
  125.                     }
  126.                     myLog("......................................................");
  127.                 }
  128.             }
  129.         } while (check);
  130.     }
  131.  
  132.     /**
  133.      * @param args the command line arguments
  134.      */
  135.     public static void main(String[] args) {
  136.  
  137.         String sArrInput[] = {""};
  138.         int numberOfCluster = 0, iSizeCuster = 0;
  139.         Properties property = new Properties();
  140.         try {
  141.             property.load(new FileInputStream(args[0]));
  142.             sArrInput = property.getProperty("data").split(",");
  143.             numberOfCluster = Integer.parseInt(property.getProperty("numberofclusters"));
  144.  
  145.         } catch (IOException ex) {
  146.             ex.printStackTrace();
  147.         }
  148.  
  149.         if (sArrInput.length < 3 || sArrInput.length > 1000
  150.                 || sArrInput.length <= (numberOfCluster * 2)) {
  151.             myLog("Input data is irrelevant!");
  152.             return;
  153.         }
  154.  
  155.         /* find size of an cluster */
  156.         iSizeCuster = sArrInput.length / numberOfCluster;
  157.  
  158.         double dArr[][] = new double[numberOfCluster][iSizeCuster];
  159.         double dCluster[] = new double[iSizeCuster];
  160.         for (int i = 0; i < sArrInput.length; i++) {
  161.             dCluster[i % iSizeCuster] = Double.parseDouble(sArrInput[i]);
  162.             if (i != 0 && (i + 1) % iSizeCuster == 0) {
  163.                 dArr[i / iSizeCuster] = dCluster;
  164.                 dCluster = new double[iSizeCuster];
  165.             }
  166.         }
  167.         if (sArrInput.length % iSizeCuster != 0) {
  168.             dCluster = ArrayUtils.subarray(dCluster, 0, (sArrInput.length % iSizeCuster));
  169.             dArr[numberOfCluster - 1] = ArrayUtils.addAll(dArr[numberOfCluster - 1], dCluster);
  170.         }
  171.  
  172.         runProgram(dArr);
  173.         try {
  174.  
  175.             FileWriter fw = new FileWriter("out.log");
  176.             BufferedWriter bw = new BufferedWriter(fw);
  177.  
  178.             double dStdDev = 0.0, dMean = 0.0;
  179.  
  180.             for (int i = 0; i < dArr.length; i++) {
  181.                 dMean = findMean(dArr[i]);
  182.                 dStdDev = findStandartDeviation(dArr[i], dMean);
  183.                 bw.write("Cluster " + i + " : " + ArrayUtils.toString(dArr[i]) + "\n");
  184.                 bw.write("Mean of Cluster " + i + " : " + Double.toString(dMean)+ "\n");
  185.                 bw.write("Stdev of Cluster " + i + " : " + Double.toString(dStdDev)+"\n");
  186.                 bw.write("\n");
  187.             }
  188.             bw.close();
  189.  
  190.         } catch (IOException ex) {
  191.             ex.printStackTrace();
  192.         }
  193.     }
  194. }
  195.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement