Advertisement
Guest User

Untitled

a guest
May 19th, 2019
108
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.51 KB | None | 0 0
  1. """
  2. Diagnostic plots for linear regression in Python.
  3. """
  4.  
  5. import matplotlib
  6. import numpy as np
  7. import statsmodels.api as sm
  8. import scipy.stats as stats
  9. import matplotlib.pyplot as plt
  10. import statsmodels.formula.api as smf
  11.  
  12. fs=20 # Set it if you want, but this will be overridden later
  13.  
  14. def abline(slope, intercept):
  15. """Plot a line from slope and intercept"""
  16. axes = plt.gca()
  17. x_vals = np.array(axes.get_xlim())
  18. y_vals = intercept + slope * x_vals
  19. plt.plot(x_vals, y_vals, '--', c='r')
  20.  
  21. def qqplot(lm_fit):
  22. resid = lm_fit.resid
  23. fig = sm.qqplot(resid, alpha=0.5)
  24. abline(1, 0)
  25. plt.xlabel('Theoretical quantiles', fontsize=fs)
  26. plt.ylabel("Sample quantiles", fontsize=fs)
  27. plt.title("Normal Q-Q", fontsize=fs)
  28. return plt
  29.  
  30. def cooks_distance(lm_fit):
  31. influences = lm_fit.get_influence()
  32. c, p = influences.cooks_distance
  33. plt.stem(np.arange(len(c)), c, markerfmt=',')
  34. plt.xlabel('Observation', fontsize=fs)
  35. plt.ylabel("Cook's distance", fontsize=fs)
  36. plt.title("Cook's distance", fontsize=fs)
  37. return plt
  38.  
  39. def residual_plot(lm_fit):
  40. fits = lm_fit.fittedvalues
  41. influences = lm_fit.get_influence()
  42. rstd = influences.resid_studentized_external
  43. plt.scatter(fits, rstd, alpha=0.5, color='blue')
  44. axes=plt.gca()
  45. x_vals = np.array(axes.get_xlim())
  46. plt.hlines([-2, 0, 2], x_vals[0], x_vals[1],
  47. colors=['r', 'k', 'r'], linestyles=['--', '--', '--'])
  48. plt.xlabel('Fitted values', fontsize=fs)
  49. plt.ylabel('Studentized residuals', fontsize=fs)
  50. plt.title('Residuals vs Fitted', fontsize=fs)
  51. return plt
  52.  
  53. def leverage(lm_fit):
  54. fig = sm.graphics.influence_plot(lm_fit, criterion="cooks")
  55. axes=plt.gca()
  56. x_vals = np.array(axes.get_xlim())
  57. plt.hlines([-2, 0, 2], x_vals[0], x_vals[1],
  58. colors=['r', 'k', 'r'], linestyles=['--', '-', '--'])
  59. plt.xlabel('Leverage', fontsize=fs)
  60. plt.ylabel('Studentized residuals', fontsize=fs)
  61. plt.title('Residuals vs Leverage', fontsize=fs)
  62. return plt
  63.  
  64. def plot_diagnostics(lm_fit):
  65. residual_plot(lm_fit).show()
  66. qqplot(lm_fit).show()
  67. cooks_distance(lm_fit).show()
  68. leverage(lm_fit).show()
  69. return None
  70.  
  71. if __name__=='__main__':
  72. params = {
  73. 'axes.titlesize': '22',
  74. 'axes.labelsize': '20',
  75. 'xtick.labelsize':'18',
  76. 'ytick.labelsize':'18'}
  77. matplotlib.rcParams.update(params)
  78. duncan_prestige = sm.datasets.get_rdataset("Duncan", "carData").data
  79. out = smf.ols('income ~ prestige', data=duncan_prestige).fit()
  80. plot_diagnostics(out)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement