Advertisement
Urchien

scatter_plot

Jul 12th, 2023
985
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.83 KB | None | 0 0
  1. def get_scatter(X, y, title, svd=True):
  2.     """
  3.    Функция для визуализации данных на плоскости с использованием метода SVD.
  4.  
  5.    Параметры:
  6.    X - матрица признаков размерности (n_samples, n_features), где n_samples - количество объектов, а n_features - количество признаков.
  7.    y - вектор меток классов размерности (n_samples).
  8.    title - заголовок графика.
  9.    svd - флаг, указывающий, нужно ли использовать метод SVD для снижения размерности данных (по умолчанию True).
  10.  
  11.    Возвращаемое значение:
  12.    График с точками на плоскости, где каждая точка представляет объект из набора данных,
  13.    цвет точки соответствует метке класса, а размер точки зависит от значения вектора y
  14.    (если значение метки класса равно 1, то размер точки равен 2, в противном случае равен 1).
  15.    """
  16.     if svd:
  17.         # Снижение размерности данных до 2 компонент с помощью метода SVD
  18.         svd = TruncatedSVD(n_components=2)
  19.         X = svd.fit_transform(X)
  20.  
  21.     # Задание размеров точек на графике в зависимости от метки класса
  22.     sizes = [2 if label == 1 else 1 for label in y]
  23.  
  24.     # Создание объектов фигуры и осей графика с заданным размером
  25.     fig, ax = plt.subplots(figsize=(15, 8))
  26.     # Рисование точек на графике
  27.     scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.get_cmap('RdYlBu_r'), s=sizes)
  28.     # Задание заголовка графика, а также скрытие делений на осях
  29.     ax.set_title(title)
  30.     ax.set_xticks([])
  31.     ax.set_yticks([])
  32.  
  33.     # Создание списка объектов типа Line2D, представляющих элементы легенды на графике
  34.     legend_elements = [
  35.         plt.Line2D([0], [0], marker='o', color='w', markersize=8,
  36.                    markerfacecolor=scatter.get_cmap()(1.0), label='Токсичные комментарии'),
  37.         plt.Line2D([0], [0], marker='o', color='w', markersize=8,
  38.                    markerfacecolor=scatter.get_cmap()(0.0), label='Нормальные комментарии')
  39.     ]
  40.     # Добавление легенды на график
  41.     ax.legend(handles=legend_elements)
  42.     # Отображение графика
  43.     plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement