Advertisement
hinagawa

ML_1

Sep 30th, 2021
784
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.05 KB | None | 0 0
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3.  
  4.  
  5. class Point:
  6.     def __init__(self, x, y, cluster=-1):
  7.         self.x = x
  8.         self.y = y
  9.         self.cluster = cluster
  10.  
  11.  
  12. def dist(a, b):
  13.     return np.sqrt((a.x - b.x) ** 2 + (a.y - b.y) ** 2)
  14.  
  15.  
  16. def rand_points(n):
  17.     points = []
  18.     for i in range(n):
  19.         point = Point(np.random.randint(0, 100), np.random.randint(0, 100))
  20.         points.append(point)
  21.     return points
  22.  
  23.  
  24. def centroids(points, k):
  25.     x_center = np.mean(list(map(lambda p: p.x, points)))
  26.     y_center = np.mean(list(map(lambda p: p.y, points)))
  27.     center = Point(x_center, y_center)
  28.     R = max(map(lambda r: dist(r, center), points))
  29.     centers = []
  30.     for i in range(k):
  31.         x_c = x_center + R * np.cos(2 * np.pi * i / k)
  32.         y_c = y_center + R * np.sin(2 * np.pi * i / k)
  33.         centers.append(Point(x_c, y_c))
  34.     return centers
  35.  
  36.  
  37. def new_centroids(points, k):
  38.     centers = []
  39.     cluster_arr = []
  40.     for i in range(k):
  41.         for point in points:
  42.             if (point.cluster == i):
  43.                 cluster_arr.append(point)
  44.         x_center = np.mean(list(map(lambda p: p.x, cluster_arr)))
  45.         y_center = np.mean(list(map(lambda p: p.y, cluster_arr)))
  46.         centers.append(Point(x_center, y_center))
  47.         cluster_arr = []
  48.     return centers
  49.  
  50.  
  51. def nearest_centroids(points, centroids):
  52.     for point in points:
  53.         min_dist = dist(point, centroids[0])
  54.         point.cluster = 0
  55.         for i in range(len(centroids)):
  56.             temp = dist(point, centroids[i])
  57.             if temp < min_dist:
  58.                 min_dist = temp
  59.                 point.cluster = i
  60.  
  61.  
  62. def color_point(points):
  63.     for point in points:
  64.         if point.cluster == 1:
  65.             plt.scatter(point.x, point.y, c='#84e098')
  66.         elif point.cluster == 2:
  67.             plt.scatter(point.x, point.y, c='#e66ae6')
  68.         elif point.cluster == 3:
  69.             plt.scatter(point.x, point.y, c='#6585c2')
  70.         else:
  71.             plt.scatter(point.x, point.y, c='black')
  72.  
  73.  
  74. def equal_centers(prev_center, centers, k):
  75.     for i in range(k):
  76.         if not (prev_center[i].x == centers[i].x and prev_center[i].y == centers[i].y):
  77.             return 0
  78.     return 1
  79.  
  80.  
  81. if __name__ == "__main__":
  82.     cluster_1 = []
  83.     cluster_2 = []
  84.     cluster_3 = []
  85.     n = 100  # кол-во тчк
  86.     k = 3  # кол-во кластеров
  87.     points = rand_points(n)
  88.     centers = centroids(points, k)
  89.     prev_center = centers
  90.     plt.scatter(list(map(lambda p: p.x, centers)), list(map(lambda p: p.y, centers)), s=100, marker='*', color='r')
  91.     nearest_centroids(points, centers)
  92.     color_point(points)
  93.     plt.show()
  94.     centers = new_centroids(points, k)
  95.     while not equal_centers(prev_center, centers, k):
  96.         plt.scatter(list(map(lambda p: p.x, centers)), list(map(lambda p: p.y, centers)), s=100, marker='*', color='r')
  97.         nearest_centroids(points, centers)
  98.         color_point(points)
  99.         plt.show()
  100.         prev_center = centers
  101.         centers = new_centroids(points, k)
  102.  
  103.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement