Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
function [coeffs, mean_loss] = fit_parabola(coeffs, x, y, visual)
% FIT_PARABOLA Fit y ~ coeffs(1)*x.^2 + coeffs(2)*x + coeffs(3) by gradient descent.
%   [coeffs, mean_loss] = FIT_PARABOLA(coeffs, x, y, visual) repeatedly
%   evaluates the parabola defined by COEFFS at X, measures the mean squared
%   error against Y, and steps COEFFS against the gradient returned by
%   AUTO_DIFF_P2 (fixed learning rate 1e-4). Training stops once the loss
%   is no longer above 1e-6, or after MAX_ITERATIONS gradient steps.
%
%   Inputs:
%     coeffs - initial coefficients; elements 1, 2 and 3 multiply x.^2,
%              x and the constant term respectively.
%     x, y   - sample points and target values. Assumed to be equally
%              sized numeric vectors (element-wise ops are used) - TODO confirm.
%     visual - when true, prints the loss every iteration and plots the
%              true parabola against intermediate and final predictions.
%
%   Outputs:
%     coeffs    - coefficients after the last gradient step.
%     mean_loss - mean squared error of the last evaluated predictions
%                 (measured before the final coefficient update).
%
%   Requires auto_diff_p2(x, y, predictions) on the MATLAB path.
%   NOTE(review): its result is subtracted from COEFFS, so it presumably
%   returns a gradient shaped like COEFFS - confirm.

% --------------- Do Not Edit this Part!!! ---------------
% Setup
iter = 0;
learning_rate = 1e-4;
plot_every_iters = 5;
max_iterations = 1e6;
mean_loss = 1;
if visual
    figure();
    hold on;
    plot(x, y, 'b', 'LineWidth', 2);
    title('True Parabola vs. Our Predictions');
    xlabel('X Axis [N.U.]');
    ylabel('Y Axis [N.U.]');
end
% Training loop
while(mean_loss > 1e-6)
    % --------------- Edit From Here ---------------
    % Enforce the iteration cap declared in the setup: the loop condition
    % only looks at the loss, so a diverging or stalling fit would otherwise
    % never terminate. PREDICTIONS already holds the previous iteration's
    % values here because ITER > 0 whenever this fires.
    if iter >= max_iterations
        break;
    end
    predictions = coeffs(1) * x.^2 + coeffs(2) * x + coeffs(3);
    % Mean of the squared residuals. Squaring the *mean* residual instead
    % lets positive and negative errors cancel, which can end training
    % while the fit is still poor.
    mean_loss = mean((predictions - y).^2);
    gradients = auto_diff_p2(x, y, predictions);
    coeffs = coeffs - learning_rate * gradients;
    % --------------- Do Not Edit From Here!!! ---------------
    if visual
        % Print the current mean error & plot true parabola vs. the current prediction
        disp(['The mean loss in iteration #', num2str(iter + 1), ' is: ',...
            num2str(mean_loss), '.']);
        % Plot densely at first, then progressively more sparsely.
        if iter == 20
            plot_every_iters = 100;
        elseif iter == 1000
            plot_every_iters = 10000;
        end
        if mod(iter, plot_every_iters) == 0
            plot(x, predictions, 'r--')
            legend('True Parabola', 'Our Current Prediction');
            drawnow;
        end
    end
    iter = iter + 1;
end
% Plot the true parabola vs. the final prediction
if visual
    close all;
    figure();
    hold on;
    plot(x, y, 'b', 'LineWidth', 2);
    plot(x, predictions, 'r--x')
    title('True Parabola vs. Our Predictions');
    xlabel('X Axis [N.U.]');
    ylabel('Y Axis [N.U.]');
    legend('True Parabola', 'Our Final Prediction');
end
end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement