Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#pragma once

#include <algorithm>
#include <cstdlib>
#include <limits>
#include <vector>

#include <QDebug>
#include <QKeyEvent>
#include <QMouseEvent>
#include <QPaintEvent>
#include <QPainter>
#include <QVector>
#include <QWidget>
#include <QtAlgorithms>
- struct KNNPoint
- {
- int category; // clasele punctelor se vor numi categorii pentru a evita confuzia cu clasele din c++ OOP
- QPoint pos;
- QColor color; // culoarea va depinde de categoria punctului
- KNNPoint(int x = 0, int y = 0, int _category = 0)
- :pos(x, y), category(_category)
- {
- }
- };
- // distanta catre un anume punct din plan
- struct DistanceTo
- {
- KNNPoint destinationPoint;
- double distance;
- double weight;
- };
- class KNN : public QWidget
- {
- int k; // numarul de vecini
- int nrCateg;
- int pointRadius; // cat de mare este fiecare punct generat pe ecran
- std::vector<KNNPoint> points;
- std::vector<QColor> categColors;
- public:
- KNN(QWidget *parent = 0) : QWidget(parent), k(3), nrCateg(2), pointRadius(6)
- {
- for (int i = 0; i < nrCateg; i++)
- // generam cate o culoare pentru fiecare categorie
- categColors.push_back(QColor::fromRgb(rand() % 256, rand() % 256, rand() % 256));
- genPoints(16);
- }
- void genPoints(int nrPoints) // genereaza aleator un numar de "nr" puncte pe ecran
- {
- for (int i = 0; i < nrPoints; i++)
- {
- points.push_back(KNNPoint(rand() % width(), rand() % height(), rand() % nrCateg));
- }
- }
- void paintEvent(QPaintEvent *e)
- {
- QPainter painter(this);
- for (auto &p : points)
- {
- painter.setBrush(categColors[p.category]);
- painter.drawEllipse(p.pos, pointRadius, pointRadius);
- }
- }
- double euclidianDistanceSquared(KNNPoint p1, KNNPoint p2) // calculeaza patratul distantei euclidiene
- {
- return (p1.pos.x() - p2.pos.x()) * (p1.pos.x() - p2.pos.x())
- + (p1.pos.y() - p2.pos.y()) * (p1.pos.y() - p2.pos.y());
- }
- // comparator pentru functia de sortare
- static bool lessThan(const DistanceTo& p1, const DistanceTo& p2)
- {
- return (p1.distance < p2.distance);
- }
- // metoda care calculeaza toate distantele de la punctul trimis ca parametru catre celelalte
- QVector<DistanceTo> getDistancesTo(KNNPoint point)
- {
- QVector<DistanceTo> distances;
- for (auto &p : points)
- {
- DistanceTo distance = {p, euclidianDistanceSquared(p, point), 0};
- distance.weight = 1. / distance.distance; // ponderea = 1 / d^2
- distances.push_back(distance);
- }
- return distances;
- }
- // clasificarea fara ponderi (cu distante)
- void clasificareFaraPonderi(QMouseEvent *e)
- {
- // cream noul punct
- KNNPoint newPoint (e->x(), e->y());
- // calculam distantele catre celelalte puncte
- QVector<DistanceTo> distances = getDistancesTo(newPoint);
- // sortam ascendent distantele
- qSort(distances.begin(), distances.end(), KNN::lessThan);
- // consideram primele k puncte (cele mai apropiate)
- QVector<DistanceTo> firstKPoints;
- for (int i = 0; i < k; ++i)
- {
- firstKPoints.push_back(distances.at(i));
- }
- // determinam numarul vecinilor din fiecare clasa
- QVector<int> nrPointsCateg = QVector<int>(nrCateg);
- for (int i = 0; i < nrCateg; ++i)
- {
- nrPointsCateg[i] = 0;
- }
- for(auto &p : firstKPoints)
- {
- nrPointsCateg[p.destinationPoint.category]++;
- }
- // aflam clasa cu numarul maxim de vecini
- int categMax = 0;
- int nrMaxVecini = nrPointsCateg[0];
- for (int i = 1; i < nrCateg; ++i)
- {
- if (nrPointsCateg[i] > nrMaxVecini)
- {
- nrMaxVecini = nrPointsCateg[i];
- categMax = i;
- }
- }
- // categorisim punctul si il adaugam
- newPoint.category = categMax;
- points.push_back(newPoint);
- }
- // clasificarea cu ponderi (1 / d^2)
- void clasificareCuPonderi(QMouseEvent *e)
- {
- // cream noul punct
- KNNPoint newPoint (e->x(), e->y());
- // calculam distantele catre celelalte puncte
- QVector<DistanceTo> distances = getDistancesTo(newPoint);
- // sortam ascendent distantele
- qSort(distances.begin(), distances.end(), KNN::lessThan);
- // consideram primele k puncte (cele mai apropiate)
- QVector<DistanceTo> firstKPoints;
- for (int i = 0; i < k; ++i)
- {
- firstKPoints.push_back(distances.at(i));
- }
- // determinam sumarea ponderilor pe clase
- QVector<double> weightsCategs = QVector<double>(nrCateg);
- for (int i = 0; i < nrCateg; ++i)
- {
- weightsCategs[i] = 0;
- }
- for(auto &p : firstKPoints)
- {
- weightsCategs[p.destinationPoint.category] += p.weight;
- }
- // aflam clasa cu suma maxima de ponderi
- int categMax = 0;
- double maxWeightSum = weightsCategs[0];
- for (int i = 1; i < nrCateg; ++i)
- {
- if (weightsCategs[i] > maxWeightSum)
- {
- maxWeightSum = weightsCategs[i];
- categMax = i;
- }
- }
- // categorisim punctul si il adaugam
- newPoint.category = categMax;
- points.push_back(newPoint);
- }
- // cross-validation
- void mousePressEvent(QMouseEvent *e)
- {
- switch (e->button())
- {
- case Qt::LeftButton:
- // clasificareFaraPonderi(e);
- clasificareCuPonderi(e);
- update();
- break;
- };
- }
- };
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement