SHARE
TWEET

Untitled

a guest Feb 25th, 2020 80 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. #pragma once
  2. #include <QWidget>
  3. #include <QPaintEvent>
  4. #include <QPainter>
  5. #include <QKeyEvent>
  6. #include <QMouseEvent>
  7. #include <QtAlgorithms>
  8. #include <QDebug>
  9. #include <vector>
  10. #include <algorithm>
  11.  
  12. struct KNNPoint
  13. {
  14.     int category; // clasele punctelor se vor numi categorii pentru a evita confuzia cu clasele din c++ OOP
  15.     QPoint pos;
  16.     QColor color; // culoarea va depinde de categoria punctului
  17.  
  18.     KNNPoint(int x = 0, int y = 0, int _category = 0)
  19.         :pos(x, y), category(_category)
  20.     {
  21.     }
  22. };
  23.  
  24. // distanta catre un anume punct din plan
  25. struct DistanceTo
  26. {
  27.     KNNPoint destinationPoint;
  28.     double distance;
  29.     double weight;
  30. };
  31.  
  32. class KNN : public QWidget
  33. {
  34.     int k; // numarul de vecini
  35.     int nrCateg;
  36.     int pointRadius; // cat de mare este fiecare punct generat pe ecran
  37.  
  38.     std::vector<KNNPoint> points;
  39.     std::vector<QColor> categColors;
  40.  
  41. public:
  42.     KNN(QWidget *parent = 0) : QWidget(parent), k(3), nrCateg(2), pointRadius(6)
  43.     {
  44.         for (int i = 0; i < nrCateg; i++)
  45.             // generam cate o culoare pentru fiecare categorie
  46.             categColors.push_back(QColor::fromRgb(rand() % 256, rand() % 256, rand() % 256));
  47.  
  48.         genPoints(16);
  49.     }
  50.  
  51.     void genPoints(int nrPoints) // genereaza aleator un numar de "nr" puncte pe ecran
  52.     {
  53.         for (int i = 0; i < nrPoints; i++)
  54.         {
  55.             points.push_back(KNNPoint(rand() % width(), rand() % height(), rand() % nrCateg));
  56.         }
  57.     }
  58.  
  59.     void paintEvent(QPaintEvent *e)
  60.     {
  61.         QPainter painter(this);
  62.  
  63.         for (auto &p : points)
  64.         {
  65.             painter.setBrush(categColors[p.category]);
  66.             painter.drawEllipse(p.pos, pointRadius, pointRadius);
  67.         }
  68.     }
  69.  
  70.     double euclidianDistanceSquared(KNNPoint p1, KNNPoint p2) // calculeaza patratul distantei euclidiene
  71.     {
  72.         return (p1.pos.x() - p2.pos.x()) * (p1.pos.x() - p2.pos.x())
  73.             + (p1.pos.y() - p2.pos.y()) * (p1.pos.y() - p2.pos.y());
  74.     }
  75.  
  76.     // comparator pentru functia de sortare
  77.     static bool lessThan(const DistanceTo& p1, const DistanceTo& p2)
  78.     {
  79.         return (p1.distance < p2.distance);
  80.     }
  81.  
  82.     // metoda care calculeaza toate distantele de la punctul trimis ca parametru catre celelalte
  83.     QVector<DistanceTo> getDistancesTo(KNNPoint point)
  84.     {
  85.         QVector<DistanceTo> distances;
  86.         for (auto &p : points)
  87.         {
  88.             DistanceTo distance = {p, euclidianDistanceSquared(p, point), 0};
  89.             distance.weight = 1. / distance.distance; // ponderea = 1 / d^2
  90.             distances.push_back(distance);
  91.         }
  92.         return distances;
  93.     }
  94.  
  95.     // clasificarea fara ponderi (cu distante)
  96.     void clasificareFaraPonderi(QMouseEvent *e)
  97.     {
  98.         // cream noul punct
  99.         KNNPoint newPoint (e->x(), e->y());
  100.  
  101.         // calculam distantele catre celelalte puncte
  102.         QVector<DistanceTo> distances = getDistancesTo(newPoint);
  103.  
  104.         // sortam ascendent distantele
  105.         qSort(distances.begin(), distances.end(), KNN::lessThan);
  106.  
  107.         // consideram primele k puncte (cele mai apropiate)
  108.         QVector<DistanceTo> firstKPoints;
  109.         for (int i = 0; i < k; ++i)
  110.         {
  111.             firstKPoints.push_back(distances.at(i));
  112.         }
  113.  
  114.         // determinam numarul vecinilor din fiecare clasa
  115.         QVector<int> nrPointsCateg = QVector<int>(nrCateg);
  116.         for (int i = 0; i < nrCateg; ++i)
  117.         {
  118.             nrPointsCateg[i] = 0;
  119.         }
  120.         for(auto &p : firstKPoints)
  121.         {
  122.             nrPointsCateg[p.destinationPoint.category]++;
  123.         }
  124.  
  125.         // aflam clasa cu numarul maxim de vecini
  126.         int categMax = 0;
  127.         int nrMaxVecini = nrPointsCateg[0];
  128.         for (int i = 1; i < nrCateg; ++i)
  129.         {
  130.             if (nrPointsCateg[i] > nrMaxVecini)
  131.             {
  132.                 nrMaxVecini = nrPointsCateg[i];
  133.                 categMax = i;
  134.             }
  135.         }
  136.  
  137.         // categorisim punctul si il adaugam
  138.         newPoint.category = categMax;
  139.         points.push_back(newPoint);
  140.     }
  141.  
  142.     // clasificarea cu ponderi (1 / d^2)
  143.     void clasificareCuPonderi(QMouseEvent *e)
  144.     {
  145.         // cream noul punct
  146.         KNNPoint newPoint (e->x(), e->y());
  147.  
  148.         // calculam distantele catre celelalte puncte
  149.         QVector<DistanceTo> distances = getDistancesTo(newPoint);
  150.  
  151.         // sortam ascendent distantele
  152.         qSort(distances.begin(), distances.end(), KNN::lessThan);
  153.  
  154.         // consideram primele k puncte (cele mai apropiate)
  155.         QVector<DistanceTo> firstKPoints;
  156.         for (int i = 0; i < k; ++i)
  157.         {
  158.             firstKPoints.push_back(distances.at(i));
  159.         }
  160.  
  161.         // determinam sumarea ponderilor pe clase
  162.         QVector<double> weightsCategs = QVector<double>(nrCateg);
  163.         for (int i = 0; i < nrCateg; ++i)
  164.         {
  165.             weightsCategs[i] = 0;
  166.         }
  167.         for(auto &p : firstKPoints)
  168.         {
  169.             weightsCategs[p.destinationPoint.category] += p.weight;
  170.         }
  171.  
  172.         // aflam clasa cu suma maxima de ponderi
  173.         int categMax = 0;
  174.         double maxWeightSum = weightsCategs[0];
  175.         for (int i = 1; i < nrCateg; ++i)
  176.         {
  177.             if (weightsCategs[i] > maxWeightSum)
  178.             {
  179.                 maxWeightSum = weightsCategs[i];
  180.                 categMax = i;
  181.             }
  182.         }
  183.  
  184.         // categorisim punctul si il adaugam
  185.         newPoint.category = categMax;
  186.         points.push_back(newPoint);
  187.     }
  188.  
  189.     // cross-validation
  190.  
  191.     void mousePressEvent(QMouseEvent *e)
  192.     {
  193.         switch (e->button())
  194.         {
  195.         case Qt::LeftButton:
  196.             // clasificareFaraPonderi(e);
  197.             clasificareCuPonderi(e);
  198.  
  199.             update();
  200.             break;
  201.         };
  202.     }
  203.  
  204. };
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Top