Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- """
- Diagnostic plots for linear regression in Python.
- """
- import matplotlib
- import numpy as np
- import statsmodels.api as sm
- import scipy.stats as stats
- import matplotlib.pyplot as plt
- import statsmodels.formula.api as smf
# Font size for every axis label and title drawn by the plotting helpers in
# this module. They pass it explicitly as ``fontsize=fs``, so it takes
# precedence over the matplotlib rcParams 'axes.labelsize' / 'axes.titlesize'.
# Nothing below reassigns it; change it here (or rebind this module attribute
# before plotting) to resize that text.
fs = 20
- def abline(slope, intercept):
- """Plot a line from slope and intercept"""
- axes = plt.gca()
- x_vals = np.array(axes.get_xlim())
- y_vals = intercept + slope * x_vals
- plt.plot(x_vals, y_vals, '--', c='r')
- def qqplot(lm_fit):
- resid = lm_fit.resid
- fig = sm.qqplot(resid, alpha=0.5)
- abline(1, 0)
- plt.xlabel('Theoretical quantiles', fontsize=fs)
- plt.ylabel("Sample quantiles", fontsize=fs)
- plt.title("Normal Q-Q", fontsize=fs)
- return plt
- def cooks_distance(lm_fit):
- influences = lm_fit.get_influence()
- c, p = influences.cooks_distance
- plt.stem(np.arange(len(c)), c, markerfmt=',')
- plt.xlabel('Observation', fontsize=fs)
- plt.ylabel("Cook's distance", fontsize=fs)
- plt.title("Cook's distance", fontsize=fs)
- return plt
- def residual_plot(lm_fit):
- fits = lm_fit.fittedvalues
- influences = lm_fit.get_influence()
- rstd = influences.resid_studentized_external
- plt.scatter(fits, rstd, alpha=0.5, color='blue')
- axes=plt.gca()
- x_vals = np.array(axes.get_xlim())
- plt.hlines([-2, 0, 2], x_vals[0], x_vals[1],
- colors=['r', 'k', 'r'], linestyles=['--', '--', '--'])
- plt.xlabel('Fitted values', fontsize=fs)
- plt.ylabel('Studentized residuals', fontsize=fs)
- plt.title('Residuals vs Fitted', fontsize=fs)
- return plt
- def leverage(lm_fit):
- fig = sm.graphics.influence_plot(lm_fit, criterion="cooks")
- axes=plt.gca()
- x_vals = np.array(axes.get_xlim())
- plt.hlines([-2, 0, 2], x_vals[0], x_vals[1],
- colors=['r', 'k', 'r'], linestyles=['--', '-', '--'])
- plt.xlabel('Leverage', fontsize=fs)
- plt.ylabel('Studentized residuals', fontsize=fs)
- plt.title('Residuals vs Leverage', fontsize=fs)
- return plt
def plot_diagnostics(lm_fit):
    """Show the four regression diagnostic plots, one after another.

    Order: residuals vs fitted, normal Q-Q, Cook's distance, residuals vs
    leverage. Each helper returns ``matplotlib.pyplot``, and ``show()`` is
    called on it immediately. Returns None.
    """
    plot_makers = (residual_plot, qqplot, cooks_distance, leverage)
    for make_plot in plot_makers:
        make_plot(lm_fit).show()
if __name__ == '__main__':
    # Enlarge plot text. Tick-label sizes take effect. Titles and axis labels
    # in this module pass fontsize=fs explicitly, which takes precedence over
    # the 'axes.*' entries here.
    matplotlib.rcParams.update({
        'axes.titlesize': '22',
        'axes.labelsize': '20',
        'xtick.labelsize': '18',
        'ytick.labelsize': '18',
    })
    # Example: regress income on prestige in the "Duncan" dataset from the
    # R "carData" package (presumably fetched over the network unless cached).
    duncan_prestige = sm.datasets.get_rdataset("Duncan", "carData").data
    out = smf.ols('income ~ prestige', data=duncan_prestige).fit()
    plot_diagnostics(out)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement