SHARE
TWEET

Untitled

a guest May 19th, 2019 77 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. """
  2. Diagnostic plots for linear regression in Python.
  3. """
  4.  
  5. import matplotlib
  6. import numpy as np
  7. import statsmodels.api as sm
  8. import scipy.stats as stats
  9. import matplotlib.pyplot as plt
  10. import statsmodels.formula.api as smf
  11.  
# Font size handed explicitly (as fontsize=fs) to every xlabel/ylabel/title
# call made by the plotting helpers in this module.
# NOTE(review): an explicit fontsize= argument normally takes precedence over
# rcParams, so the rcParams set under __main__ probably do NOT override this
# value for labels and titles - confirm.
fs=20
  13.  
  14. def abline(slope, intercept):
  15.     """Plot a line from slope and intercept"""
  16.     axes = plt.gca()
  17.     x_vals = np.array(axes.get_xlim())
  18.     y_vals = intercept + slope * x_vals
  19.     plt.plot(x_vals, y_vals, '--', c='r')
  20.  
  21. def qqplot(lm_fit):
  22.     resid = lm_fit.resid
  23.     fig = sm.qqplot(resid, alpha=0.5)
  24.     abline(1, 0)
  25.     plt.xlabel('Theoretical quantiles', fontsize=fs)
  26.     plt.ylabel("Sample quantiles", fontsize=fs)
  27.     plt.title("Normal Q-Q", fontsize=fs)
  28.     return plt
  29.  
  30. def cooks_distance(lm_fit):
  31.     influences = lm_fit.get_influence()
  32.     c, p = influences.cooks_distance
  33.     plt.stem(np.arange(len(c)), c, markerfmt=',')
  34.     plt.xlabel('Observation', fontsize=fs)
  35.     plt.ylabel("Cook's distance", fontsize=fs)
  36.     plt.title("Cook's distance", fontsize=fs)
  37.     return plt
  38.  
  39. def residual_plot(lm_fit):
  40.     fits = lm_fit.fittedvalues
  41.     influences = lm_fit.get_influence()
  42.     rstd = influences.resid_studentized_external
  43.     plt.scatter(fits, rstd, alpha=0.5, color='blue')
  44.     axes=plt.gca()
  45.     x_vals = np.array(axes.get_xlim())
  46.     plt.hlines([-2, 0, 2], x_vals[0], x_vals[1],
  47.             colors=['r', 'k', 'r'], linestyles=['--', '--', '--'])
  48.     plt.xlabel('Fitted values', fontsize=fs)
  49.     plt.ylabel('Studentized residuals', fontsize=fs)
  50.     plt.title('Residuals vs Fitted', fontsize=fs)
  51.     return plt
  52.  
  53. def leverage(lm_fit):
  54.     fig = sm.graphics.influence_plot(lm_fit, criterion="cooks")
  55.     axes=plt.gca()
  56.     x_vals = np.array(axes.get_xlim())
  57.     plt.hlines([-2, 0, 2], x_vals[0], x_vals[1],
  58.             colors=['r', 'k', 'r'], linestyles=['--', '-', '--'])
  59.     plt.xlabel('Leverage', fontsize=fs)
  60.     plt.ylabel('Studentized residuals', fontsize=fs)
  61.     plt.title('Residuals vs Leverage', fontsize=fs)
  62.     return plt
  63.  
  64. def plot_diagnostics(lm_fit):
  65.     residual_plot(lm_fit).show()
  66.     qqplot(lm_fit).show()
  67.     cooks_distance(lm_fit).show()
  68.     leverage(lm_fit).show()
  69.     return None
  70.    
  71. if __name__=='__main__':
  72.     params = {
  73.         'axes.titlesize': '22',
  74.         'axes.labelsize': '20',
  75.         'xtick.labelsize':'18',
  76.         'ytick.labelsize':'18'}
  77.     matplotlib.rcParams.update(params)
  78.     duncan_prestige = sm.datasets.get_rdataset("Duncan", "carData").data
  79.     out = smf.ols('income ~ prestige', data=duncan_prestige).fit()
  80.     plot_diagnostics(out)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top