Advertisement
Guest User

Untitled

a guest
Feb 25th, 2020
126
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.60 KB | None | 0 0
  1. #pragma once
  2. #include <QWidget>
  3. #include <QPaintEvent>
  4. #include <QPainter>
  5. #include <QKeyEvent>
  6. #include <QMouseEvent>
  7. #include <QtAlgorithms>
  8. #include <QDebug>
  9. #include <vector>
  10. #include <algorithm>
  11.  
  12. struct KNNPoint
  13. {
  14. int category; // clasele punctelor se vor numi categorii pentru a evita confuzia cu clasele din c++ OOP
  15. QPoint pos;
  16. QColor color; // culoarea va depinde de categoria punctului
  17.  
  18. KNNPoint(int x = 0, int y = 0, int _category = 0)
  19. :pos(x, y), category(_category)
  20. {
  21. }
  22. };
  23.  
  24. // distanta catre un anume punct din plan
  25. struct DistanceTo
  26. {
  27. KNNPoint destinationPoint;
  28. double distance;
  29. double weight;
  30. };
  31.  
  32. class KNN : public QWidget
  33. {
  34. int k; // numarul de vecini
  35. int nrCateg;
  36. int pointRadius; // cat de mare este fiecare punct generat pe ecran
  37.  
  38. std::vector<KNNPoint> points;
  39. std::vector<QColor> categColors;
  40.  
  41. public:
  42. KNN(QWidget *parent = 0) : QWidget(parent), k(3), nrCateg(2), pointRadius(6)
  43. {
  44. for (int i = 0; i < nrCateg; i++)
  45. // generam cate o culoare pentru fiecare categorie
  46. categColors.push_back(QColor::fromRgb(rand() % 256, rand() % 256, rand() % 256));
  47.  
  48. genPoints(16);
  49. }
  50.  
  51. void genPoints(int nrPoints) // genereaza aleator un numar de "nr" puncte pe ecran
  52. {
  53. for (int i = 0; i < nrPoints; i++)
  54. {
  55. points.push_back(KNNPoint(rand() % width(), rand() % height(), rand() % nrCateg));
  56. }
  57. }
  58.  
  59. void paintEvent(QPaintEvent *e)
  60. {
  61. QPainter painter(this);
  62.  
  63. for (auto &p : points)
  64. {
  65. painter.setBrush(categColors[p.category]);
  66. painter.drawEllipse(p.pos, pointRadius, pointRadius);
  67. }
  68. }
  69.  
  70. double euclidianDistanceSquared(KNNPoint p1, KNNPoint p2) // calculeaza patratul distantei euclidiene
  71. {
  72. return (p1.pos.x() - p2.pos.x()) * (p1.pos.x() - p2.pos.x())
  73. + (p1.pos.y() - p2.pos.y()) * (p1.pos.y() - p2.pos.y());
  74. }
  75.  
  76. // comparator pentru functia de sortare
  77. static bool lessThan(const DistanceTo& p1, const DistanceTo& p2)
  78. {
  79. return (p1.distance < p2.distance);
  80. }
  81.  
  82. // metoda care calculeaza toate distantele de la punctul trimis ca parametru catre celelalte
  83. QVector<DistanceTo> getDistancesTo(KNNPoint point)
  84. {
  85. QVector<DistanceTo> distances;
  86. for (auto &p : points)
  87. {
  88. DistanceTo distance = {p, euclidianDistanceSquared(p, point), 0};
  89. distance.weight = 1. / distance.distance; // ponderea = 1 / d^2
  90. distances.push_back(distance);
  91. }
  92. return distances;
  93. }
  94.  
  95. // clasificarea fara ponderi (cu distante)
  96. void clasificareFaraPonderi(QMouseEvent *e)
  97. {
  98. // cream noul punct
  99. KNNPoint newPoint (e->x(), e->y());
  100.  
  101. // calculam distantele catre celelalte puncte
  102. QVector<DistanceTo> distances = getDistancesTo(newPoint);
  103.  
  104. // sortam ascendent distantele
  105. qSort(distances.begin(), distances.end(), KNN::lessThan);
  106.  
  107. // consideram primele k puncte (cele mai apropiate)
  108. QVector<DistanceTo> firstKPoints;
  109. for (int i = 0; i < k; ++i)
  110. {
  111. firstKPoints.push_back(distances.at(i));
  112. }
  113.  
  114. // determinam numarul vecinilor din fiecare clasa
  115. QVector<int> nrPointsCateg = QVector<int>(nrCateg);
  116. for (int i = 0; i < nrCateg; ++i)
  117. {
  118. nrPointsCateg[i] = 0;
  119. }
  120. for(auto &p : firstKPoints)
  121. {
  122. nrPointsCateg[p.destinationPoint.category]++;
  123. }
  124.  
  125. // aflam clasa cu numarul maxim de vecini
  126. int categMax = 0;
  127. int nrMaxVecini = nrPointsCateg[0];
  128. for (int i = 1; i < nrCateg; ++i)
  129. {
  130. if (nrPointsCateg[i] > nrMaxVecini)
  131. {
  132. nrMaxVecini = nrPointsCateg[i];
  133. categMax = i;
  134. }
  135. }
  136.  
  137. // categorisim punctul si il adaugam
  138. newPoint.category = categMax;
  139. points.push_back(newPoint);
  140. }
  141.  
  142. // clasificarea cu ponderi (1 / d^2)
  143. void clasificareCuPonderi(QMouseEvent *e)
  144. {
  145. // cream noul punct
  146. KNNPoint newPoint (e->x(), e->y());
  147.  
  148. // calculam distantele catre celelalte puncte
  149. QVector<DistanceTo> distances = getDistancesTo(newPoint);
  150.  
  151. // sortam ascendent distantele
  152. qSort(distances.begin(), distances.end(), KNN::lessThan);
  153.  
  154. // consideram primele k puncte (cele mai apropiate)
  155. QVector<DistanceTo> firstKPoints;
  156. for (int i = 0; i < k; ++i)
  157. {
  158. firstKPoints.push_back(distances.at(i));
  159. }
  160.  
  161. // determinam sumarea ponderilor pe clase
  162. QVector<double> weightsCategs = QVector<double>(nrCateg);
  163. for (int i = 0; i < nrCateg; ++i)
  164. {
  165. weightsCategs[i] = 0;
  166. }
  167. for(auto &p : firstKPoints)
  168. {
  169. weightsCategs[p.destinationPoint.category] += p.weight;
  170. }
  171.  
  172. // aflam clasa cu suma maxima de ponderi
  173. int categMax = 0;
  174. double maxWeightSum = weightsCategs[0];
  175. for (int i = 1; i < nrCateg; ++i)
  176. {
  177. if (weightsCategs[i] > maxWeightSum)
  178. {
  179. maxWeightSum = weightsCategs[i];
  180. categMax = i;
  181. }
  182. }
  183.  
  184. // categorisim punctul si il adaugam
  185. newPoint.category = categMax;
  186. points.push_back(newPoint);
  187. }
  188.  
  189. // cross-validation
  190.  
  191. void mousePressEvent(QMouseEvent *e)
  192. {
  193. switch (e->button())
  194. {
  195. case Qt::LeftButton:
  196. // clasificareFaraPonderi(e);
  197. clasificareCuPonderi(e);
  198.  
  199. update();
  200. break;
  201. };
  202. }
  203.  
  204. };
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement