Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- class LinearRegression:
- def get_loss(self, preds, y):
- """
- @param preds: предсказания модели
- @param y: истиные значения
- @return mse: значение MSE на переданных данных
- """
- # возьмите средний квадрат ошибки по всем выходным переменным
- # то есть сумму квадратов ошибки надо поделить на количество_элементов * количество_таргетов
- return np.mean((preds - y)**2)
- def init_weights(self, input_size, output_size):
- """
- Инициализирует параметры модели
- W - матрица размерности (input_size, output_size)
- инициализируется рандомными числами из
- uniform распределения (torch.rand())
- b - вектор размерности (1, output_size)
- инициализируется нулями
- """
- torch.manual_seed(0) #необходимо для воспроизводимости результатов
- self.W = torch.rand(input_size, output_size, requires_grad=True)
- self.b = torch.zeros(1, output_size, requires_grad=True)
- def fit(self, X, y, num_epochs=1000, lr=0.001):
- """
- Обучение модели линейной регрессии методом градиентного спуска
- @param X: размерности (num_samples, input_shape)
- @param y: размерности (num_samples, output_shape)
- @param num_epochs: количество итераций градиентного спуска
- @param lr: шаг градиентного спуска
- @return metrics: вектор значений MSE на каждом шаге градиентного
- спуска.
- """
- self.init_weights(X.shape[1], y.shape[1])
- metrics = []
- for _ in range(num_epochs):
- # сделайте вычисления градиентов c помощью Pytorch и обновите веса
- # осторожнее, оберните вычитание градиента в
- # with torch.no_grad():
- # #some code
- # иначе во время прибавления градиента к переменной создастся очень много нод в дереве операций
- # и ваши модели в будущем будут падать от нехватки памяти
- preds = self.predict(X)
- loss = self.get_loss(preds, y)
- with torch.no_grad():
- loss.backward()
- self.W -= lr * self.W.grad
- self.b -= lr * self.b.grad
- self.W.zero_grad()
- self.b.zero_grad()
- metrics.append(self.get_loss(preds, y).data)
- return metrics
- def predict(self, X):
- """
- Думаю, тут все понятно. Сделайте свои предсказания :)
- """
- return X @ self.W + self.b
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement