Advertisement
Michal85

EM algorithm C++

Feb 13th, 2014
286
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.57 KB | None | 0 0
  1. /**
  2.  * @brief Performs EM
  3.  * @param f function
  4.  * @param x input points
  5.  * @return critirial function value
  6.  */
  7. double ExpectationMaximization::maximize(SumGauss1D* f, const Points1D &pts) {
  8.  
  9.   // Delta value
  10.   const double DELTA = 1e-5;
  11.  
  12.   // Cache gaussians
  13.   this->gaussians = f->getGaussians();
  14.  
  15.   // Perform algorithm
  16.   double prevCrit = 0, crit = 0;
  17.   do {
  18.     Array2D<double> hs = estep(pts);
  19.     mstep(pts, hs);
  20.     prevCrit = crit;
  21.     crit = criterium(pts, hs);
  22.     QString m = QObject::tr("Checking criterium previous: %1, current %2...");
  23.     logDbg(m.arg(prevCrit).arg(crit));
  24.   } while (fabs(crit - prevCrit) > DELTA * fabs(prevCrit));
  25.  
  26.   return crit;
  27. }
  28. //---------------------------------------------------------------------------
  29.  
  30. /**
  31.  * @brief Expectation step (finds probabilities that points belong to gaussians)
  32.  * @param x points
  33.  * @return table of posterior probabilities
  34.  */
  35. Array2D<double> ExpectationMaximization::estep(const Points1D &x) {
  36.  
  37.   // Configure bayes
  38.   logDbg(QObject::tr("Performing E-step..."));
  39.   bayes->setGaussians(gaussians);
  40.  
  41.   // Calculate probabilities that i-th point belongs to j-th class
  42.   Array2D<double> hs(x.size(), gaussians.size());
  43.   for (int i = 0; i < x.size(); i++) {
  44.     for (int j = 0; j < gaussians.size(); j++) {
  45.       bayes->setIndex(j);
  46.       hs.set(i, j, bayes->eval(x.at(i)));
  47.     }
  48.   }
  49.   return hs;
  50. }
  51. //---------------------------------------------------------------------------
  52.  
  53. /**
  54.  * @brief Maximization step (updates gaussians based on probabilities)
  55.  * @param x points
  56.  * @param hs table of posterior probabilities
  57.  */
  58. void ExpectationMaximization::mstep(const Points1D &x,
  59.                                     const Array2D<double> &hs) {
  60.  
  61.   // Check dimensions
  62.   logDbg(QObject::tr("Performing M-step..."));
  63.   if (hs.colCount() != gaussians.size()) {
  64.     logErr(QObject::tr("mStep: invalid number of gaussians."));
  65.     return;
  66.   }
  67.   if (hs.rowCount() != x.size()) {
  68.     logErr(QObject::tr("mStep: invalid number of points."));
  69.     return;
  70.   }
  71.  
  72.   // Counts
  73.   int gaussCount = hs.colCount();
  74.   int pointCount = hs.rowCount();
  75.  
  76.   // Iterate through all components
  77.   for (int i = 0; i < gaussCount; i++) {
  78.  
  79.     // Current gaussian
  80.     Gauss1D* g = gaussians.at(i);
  81.  
  82.     // Sum of posterior probabilities
  83.     QVector<double> probs = hs.getCol(i);
  84.     double sum = MathUtils::sum(probs);
  85.  
  86.     // Update parameters of current gaussian
  87.     g->setHeight(1.0 / pointCount * sum);
  88.     g->setMean(MathUtils::dot(x, probs) / sum);
  89.     QVector<double> v = MathUtils::sub(x, g->getMean());
  90.     g->setSigma(sqrt(MathUtils::dot(v, v) / sum));
  91.  
  92.     // Debug
  93.     QString m = QObject::tr("Component %1 { mean: %2, sigma: %3, height %4 }");
  94.     logDbg(m.arg(i).arg(g->getMean()).arg(g->getSigma()).arg(g->getHeight()));
  95.   }
  96. }
  97. //---------------------------------------------------------------------------
  98.  
  99. /**
  100.  * @brief Stop criterium
  101.  * @param x points
  102.  * @param hs table of posterior probabilities
  103.  * @return value
  104.  */
  105. double ExpectationMaximization::criterium(const Points1D &x,
  106.                                           const Array2D<double> &hs) {
  107.  
  108.   // Counts
  109.   int gaussCount = hs.colCount();
  110.   int pointCount = hs.rowCount();
  111.  
  112.   // Calculate criterial function
  113.   double crit = 0;
  114.   for (int i = 0; i < pointCount; i++) {
  115.     for (int j = 0; j < gaussCount; j++) {
  116.       Gauss1D* g = gaussians.at(j);
  117.       crit += hs.get(i, j) * log(g->getHeight() * g->eval(x.at(i)));
  118.     }
  119.   }
  120.   return crit;
  121. }
  122. //---------------------------------------------------------------------------
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement