Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- def get_scatter(X, y, title, svd=True):
- """
- Функция для визуализации данных на плоскости с использованием метода SVD.
- Параметры:
- X - матрица признаков размерности (n_samples, n_features), где n_samples - количество объектов, а n_features - количество признаков.
- y - вектор меток классов размерности (n_samples).
- title - заголовок графика.
- svd - флаг, указывающий, нужно ли использовать метод SVD для снижения размерности данных (по умолчанию True).
- Возвращаемое значение:
- График с точками на плоскости, где каждая точка представляет объект из набора данных,
- цвет точки соответствует метке класса, а размер точки зависит от значения вектора y
- (если значение метки класса равно 1, то размер точки равен 2, в противном случае равен 1).
- """
- if svd:
- # Снижение размерности данных до 2 компонент с помощью метода SVD
- svd = TruncatedSVD(n_components=2)
- X = svd.fit_transform(X)
- # Задание размеров точек на графике в зависимости от метки класса
- sizes = [2 if label == 1 else 1 for label in y]
- # Создание объектов фигуры и осей графика с заданным размером
- fig, ax = plt.subplots(figsize=(15, 8))
- # Рисование точек на графике
- scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.get_cmap('RdYlBu_r'), s=sizes)
- # Задание заголовка графика, а также скрытие делений на осях
- ax.set_title(title)
- ax.set_xticks([])
- ax.set_yticks([])
- # Создание списка объектов типа Line2D, представляющих элементы легенды на графике
- legend_elements = [
- plt.Line2D([0], [0], marker='o', color='w', markersize=8,
- markerfacecolor=scatter.get_cmap()(1.0), label='Токсичные комментарии'),
- plt.Line2D([0], [0], marker='o', color='w', markersize=8,
- markerfacecolor=scatter.get_cmap()(0.0), label='Нормальные комментарии')
- ]
- # Добавление легенды на график
- ax.legend(handles=legend_elements)
- # Отображение графика
- plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement