Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- /**
- * @brief Performs EM
- * @param f function
- * @param x input points
- * @return critirial function value
- */
- double ExpectationMaximization::maximize(SumGauss1D* f, const Points1D &pts) {
- // Delta value
- const double DELTA = 1e-5;
- // Cache gaussians
- this->gaussians = f->getGaussians();
- // Perform algorithm
- double prevCrit = 0, crit = 0;
- do {
- Array2D<double> hs = estep(pts);
- mstep(pts, hs);
- prevCrit = crit;
- crit = criterium(pts, hs);
- QString m = QObject::tr("Checking criterium previous: %1, current %2...");
- logDbg(m.arg(prevCrit).arg(crit));
- } while (fabs(crit - prevCrit) > DELTA * fabs(prevCrit));
- return crit;
- }
- //---------------------------------------------------------------------------
- /**
- * @brief Expectation step (finds probabilities that points belong to gaussians)
- * @param x points
- * @return table of posterior probabilities
- */
- Array2D<double> ExpectationMaximization::estep(const Points1D &x) {
- // Configure bayes
- logDbg(QObject::tr("Performing E-step..."));
- bayes->setGaussians(gaussians);
- // Calculate probabilities that i-th point belongs to j-th class
- Array2D<double> hs(x.size(), gaussians.size());
- for (int i = 0; i < x.size(); i++) {
- for (int j = 0; j < gaussians.size(); j++) {
- bayes->setIndex(j);
- hs.set(i, j, bayes->eval(x.at(i)));
- }
- }
- return hs;
- }
- //---------------------------------------------------------------------------
- /**
- * @brief Maximization step (updates gaussians based on probabilities)
- * @param x points
- * @param hs table of posterior probabilities
- */
- void ExpectationMaximization::mstep(const Points1D &x,
- const Array2D<double> &hs) {
- // Check dimensions
- logDbg(QObject::tr("Performing M-step..."));
- if (hs.colCount() != gaussians.size()) {
- logErr(QObject::tr("mStep: invalid number of gaussians."));
- return;
- }
- if (hs.rowCount() != x.size()) {
- logErr(QObject::tr("mStep: invalid number of points."));
- return;
- }
- // Counts
- int gaussCount = hs.colCount();
- int pointCount = hs.rowCount();
- // Iterate through all components
- for (int i = 0; i < gaussCount; i++) {
- // Current gaussian
- Gauss1D* g = gaussians.at(i);
- // Sum of posterior probabilities
- QVector<double> probs = hs.getCol(i);
- double sum = MathUtils::sum(probs);
- // Update parameters of current gaussian
- g->setHeight(1.0 / pointCount * sum);
- g->setMean(MathUtils::dot(x, probs) / sum);
- QVector<double> v = MathUtils::sub(x, g->getMean());
- g->setSigma(sqrt(MathUtils::dot(v, v) / sum));
- // Debug
- QString m = QObject::tr("Component %1 { mean: %2, sigma: %3, height %4 }");
- logDbg(m.arg(i).arg(g->getMean()).arg(g->getSigma()).arg(g->getHeight()));
- }
- }
- //---------------------------------------------------------------------------
- /**
- * @brief Stop criterium
- * @param x points
- * @param hs table of posterior probabilities
- * @return value
- */
- double ExpectationMaximization::criterium(const Points1D &x,
- const Array2D<double> &hs) {
- // Counts
- int gaussCount = hs.colCount();
- int pointCount = hs.rowCount();
- // Calculate criterial function
- double crit = 0;
- for (int i = 0; i < pointCount; i++) {
- for (int j = 0; j < gaussCount; j++) {
- Gauss1D* g = gaussians.at(j);
- crit += hs.get(i, j) * log(g->getHeight() * g->eval(x.at(i)));
- }
- }
- return crit;
- }
- //---------------------------------------------------------------------------
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement