Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import os
from typing import Any, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import statsmodels.api as sm
from linearmodels.iv.results import IVResults
from sklearn.linear_model import LinearRegression
from statsmodels.genmod.generalized_linear_model import GLMResults
from statsmodels.regression.linear_model import RegressionResults
from statsmodels.robust.robust_linear_model import RLMResults
def _first_attr(obj, names):
    """Return the first attribute of *obj* listed in *names* that exists and is not None, else None."""
    for name in names:
        value = getattr(obj, name, None)
        if value is not None:
            return value
    return None


def _fitted_and_residuals(model):
    """Return ``(fitted_values, residuals)`` of a statsmodels/linearmodels result as 1-D arrays.

    statsmodels results expose ``fittedvalues``/``resid`` (with ``resid_response``
    as a fallback for GLM results), while linearmodels results use
    ``fitted_values``/``resids``; the first attribute found in each group is used.

    Raises:
        AttributeError: if no fitted-value or residual attribute is present.
    """
    fitted = _first_attr(model, ('fittedvalues', 'fitted_values'))
    resid = _first_attr(model, ('resid', 'resids', 'resid_response'))
    if fitted is None or resid is None:
        raise AttributeError("Model object does not have necessary attributes 'fittedvalues' or 'resid'. Ensure you are passing a valid statsmodels model.")
    # linearmodels returns single-column DataFrames; flatten everything to 1-D.
    return np.ravel(np.asarray(fitted)), np.ravel(np.asarray(resid))


def _design_matrix(model, X):
    """Return *X* as a 2-D float array, prefixed with a column of ones when the sklearn model fits an intercept."""
    design = np.asarray(X, dtype=float).reshape(len(X), -1)
    if getattr(model, 'fit_intercept', False):
        design = np.column_stack([np.ones(len(design)), design])
    return design


def _hat_diagonal(design):
    """Return the leverages, i.e. the diagonal of ``H = D (D'D)^+ D'`` for design matrix *design*."""
    xtx_pinv = np.linalg.pinv(design.T @ design)
    # Row-wise d_i' (D'D)^+ d_i, without materialising the n x n hat matrix.
    return np.sum((design @ xtx_pinv) * design, axis=1)


def _cooks_distance(design, residuals):
    """Return Cook's distance for every observation of a least-squares fit.

    Uses ``D_i = e_i^2 / (p * s^2) * h_i / (1 - h_i)^2`` where ``p`` is the rank of
    the design matrix, ``h_i`` the leverage and ``s^2 = RSS / (n - p)``.

    Raises:
        ValueError: if there are not more observations than parameters.
    """
    n_obs = design.shape[0]
    n_params = np.linalg.matrix_rank(design)
    if n_obs <= n_params:
        raise ValueError("Cook's distance requires more observations than model parameters.")
    leverage = _hat_diagonal(design)
    mse = (residuals @ residuals) / (n_obs - n_params)
    return residuals ** 2 / (n_params * mse) * leverage / (1 - leverage) ** 2


def plot(model: Union[RegressionResults, IVResults, GLMResults, RLMResults, LinearRegression],
         y: Optional[np.ndarray] = None,
         X: Optional[np.ndarray] = None,
         plot_type: str = 'residual',
         figsize: Tuple[int, int] = (10, 6),
         color: str = 'blue',
         marker: str = 'o',
         save_path: Optional[str] = None,
         plot_params: Optional[Dict[str, Any]] = None,
         **kwargs):
    """Draw a regression diagnostic plot for a fitted model.

    Args:
        model: A statsmodels ``RegressionResults``/``GLMResults``/``RLMResults``,
            a linearmodels ``IVResults``, or a fitted sklearn ``LinearRegression``.
        y: Observed targets (NumPy array); required for ``LinearRegression``.
        X: Feature matrix passed to ``model.predict`` (NumPy array, without an
            intercept column); required for ``LinearRegression``.
        plot_type: 'residual', 'qq', 'leverage', 'cooks', 'influence',
            'partial_regression' or 'residual_density'. 'influence' and
            'partial_regression' are only available for statsmodels models;
            'leverage', 'cooks', 'influence' and 'partial_regression' call
            statsmodels influence APIs that not every results class supports.
        figsize: Figure size in inches.
        color: Colour of the scatter / stem / density plots.
        marker: Marker style for scatter plots.
        save_path: If given, save the figure there (creating missing parent
            directories); otherwise ``plt.show()`` is called.
        plot_params: Optional dict of sub-dicts: ``'figure'`` (extra kwargs for
            figure creation), ``'axes'`` (applied with ``Axes.set`` to every
            axes of the finished figure) and ``'seaborn'`` (extra kwargs for
            ``sns.kdeplot``).
        **kwargs: Forwarded to the scatter call of residual/leverage plots and
            to ``sns.kdeplot``.

    Returns:
        The matplotlib ``Figure``, or ``None`` when an error occurred inside the
        plotting stage (the error is printed, not raised).

    Raises:
        ValueError: if ``plot_params`` is not a dict (checked before plotting).
    """
    if plot_params is None:
        plot_params = {}
    elif not isinstance(plot_params, dict):
        raise ValueError("plot_params must be a dictionary.")
    # Extract matplotlib and seaborn specific parameters
    fig_params = plot_params.get('figure', {})
    ax_params = plot_params.get('axes', {})
    sns_params = plot_params.get('seaborn', {})

    def new_axes():
        # Every single-panel plot honours figsize and the 'figure' params.
        return plt.subplots(figsize=figsize, **fig_params)

    def plot_against_residuals(x_values, residuals, xlabel, title):
        fig, ax = new_axes()
        ax.scatter(x_values, residuals, color=color, marker=marker, **kwargs)
        ax.axhline(0, color='red', linestyle='--')
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Residuals')
        ax.set_title(title)
        return fig

    def plot_qq(residuals):
        fig, ax = new_axes()
        # fit=True standardises the residuals; without it the 45-degree line is
        # only a valid reference for residuals that already have unit variance.
        sm.qqplot(residuals, fit=True, line='45', ax=ax)
        ax.set_title('Q-Q Plot')
        return fig

    def plot_cooks(distances):
        fig, ax = new_axes()
        index = np.arange(len(distances))
        ax.vlines(index, 0, distances, color=color)
        ax.scatter(index, distances, color=color, marker=marker)
        ax.set_xlabel('Observation index')
        ax.set_ylabel("Cook's distance")
        ax.set_title("Cook's Distance Plot")
        return fig

    def plot_residual_density(residuals):
        fig, ax = new_axes()
        # 'fill' replaces seaborn's deprecated (later removed) 'shade' argument.
        # Merging into one dict lets kwargs / seaborn params override the
        # defaults instead of raising on duplicate keyword arguments.
        kde_kwargs = {'fill': True, 'color': color, **kwargs, **sns_params}
        sns.kdeplot(residuals, ax=ax, **kde_kwargs)
        ax.set_title('Residual Density Plot')
        ax.set_xlabel('Residuals')
        return fig

    def plot_statsmodels_graphic(draw, title):
        # Draw a statsmodels graphic onto our own axes so figsize is honoured.
        fig, ax = new_axes()
        draw(model, ax=ax)
        ax.set_title(title)
        return fig

    def plot_partial_regression():
        fig = plt.figure(figsize=figsize, **fig_params)
        sm.graphics.plot_partregress_grid(model, fig=fig)
        # suptitle: plt.title would only label the last panel of the grid.
        fig.suptitle('Partial Regression Plots')
        return fig

    try:
        if isinstance(model, (RegressionResults, IVResults, GLMResults, RLMResults)):
            fitted_values, residuals = _fitted_and_residuals(model)
            # Zero-argument builders: only the requested plot (and any influence
            # computation it needs) is evaluated.
            builders = {
                'residual': lambda: plot_against_residuals(
                    fitted_values, residuals, 'Fitted values', 'Residuals vs Fitted values'),
                'qq': lambda: plot_qq(residuals),
                'leverage': lambda: plot_statsmodels_graphic(
                    sm.graphics.plot_leverage_resid2, 'Leverage vs. Residuals'),
                'cooks': lambda: plot_cooks(
                    np.asarray(model.get_influence().cooks_distance[0])),
                'influence': lambda: plot_statsmodels_graphic(
                    lambda res, ax: sm.graphics.influence_plot(res, criterion="cooks", ax=ax),
                    'Influence Plot'),
                'partial_regression': plot_partial_regression,
                'residual_density': lambda: plot_residual_density(residuals),
            }
            build = builders.get(plot_type)
            if build is None:
                raise ValueError("Unsupported plot_type for statsmodels models. Choose from: 'residual', 'qq', 'leverage', 'cooks', 'influence', 'partial_regression', 'residual_density'.")
        elif isinstance(model, LinearRegression):
            if not isinstance(y, np.ndarray) or not isinstance(X, np.ndarray):
                raise ValueError("y and X must be NumPy arrays for LinearRegression models.")
            if plot_type == 'partial_regression':
                raise NotImplementedError("Partial regression plots are not implemented for sklearn models.")
            if plot_type == 'influence':
                raise NotImplementedError("Influence plots are not implemented for sklearn models.")
            y_pred = np.ravel(model.predict(X))
            # Ravel y so an (n, 1) target does not broadcast to an (n, n) result.
            residuals = np.ravel(y) - y_pred
            design = _design_matrix(model, X)
            builders = {
                'residual': lambda: plot_against_residuals(
                    y_pred, residuals, 'Fitted values', 'Residuals vs Fitted values'),
                'qq': lambda: plot_qq(residuals),
                'leverage': lambda: plot_against_residuals(
                    _hat_diagonal(design), residuals, 'Leverage', 'Leverage vs Residuals'),
                'cooks': lambda: plot_cooks(_cooks_distance(design, residuals)),
                'residual_density': lambda: plot_residual_density(residuals),
            }
            build = builders.get(plot_type)
            if build is None:
                raise ValueError("Unsupported plot_type for LinearRegression models. Choose from: 'residual', 'qq', 'leverage', 'cooks', 'residual_density'.")
        else:
            raise TypeError("Unsupported model type. Supported types are statsmodels RegressionResults, IVResults, GLMResults, RLMResults, and sklearn LinearRegression.")

        fig = build()
        if ax_params:
            for ax in fig.axes:
                ax.set(**ax_params)

        if save_path:
            directory = os.path.dirname(save_path)
            # A bare filename has no directory part; os.makedirs('') would raise.
            if directory:
                os.makedirs(directory, exist_ok=True)
            fig.savefig(save_path)
        else:
            plt.show()
        return fig
    except ValueError as e:
        print(f"ValueError: {e}. Ensure that the provided 'plot_type' is correct and that 'y' and 'X' are NumPy arrays for LinearRegression models.")
    except TypeError as e:
        print(f"TypeError: {e}. Supported model types are statsmodels RegressionResults, IVResults, GLMResults, RLMResults, and sklearn LinearRegression.")
    except AttributeError as e:
        print(f"AttributeError: {e}. Ensure the model object has the necessary attributes and is a valid statsmodels model.")
    except NotImplementedError as e:
        print(f"NotImplementedError: {e}. This plot type is not implemented for the provided model type.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
# Usage for statsmodels or linearmodels results:
# plot(model, plot_type='residual')            # Residuals vs fitted values
# plot(model, plot_type='qq')                  # Q-Q plot
# plot(model, plot_type='leverage')            # Leverage vs residuals plot
# plot(model, plot_type='cooks')               # Cook's distance plot
# plot(model, plot_type='influence')           # Influence plot
# plot(model, plot_type='partial_regression')  # Partial regression plots
# plot(model, plot_type='residual_density')    # Residual density plot
# Usage for a fitted scikit-learn LinearRegression (y and X must be NumPy arrays):
# plot(model, y, X, plot_type='residual')          # Residuals vs fitted values
# plot(model, y, X, plot_type='qq')                # Q-Q plot
# plot(model, y, X, plot_type='leverage')          # Leverage vs residuals plot
# plot(model, y, X, plot_type='cooks')             # Cook's distance plot
# plot(model, y, X, plot_type='residual_density')  # Residual density plot
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement