Advertisement
mystis

lpm-vs-logit

May 21st, 2020
387
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.01 KB | None | 0 0
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3.  
  4. ### Imports ###
  5.  
  6. ## Requires matplotlib, numpy, pandas, statsmodels and palettable to be installed ##
  7.  
  8. import matplotlib.pyplot as plt  # Plotting
  9. import numpy as np  # Linear algebra and random data generation
  10. import pandas as pd  # I/O and data manipulation
  11. import statsmodels.api as sm  # Regressions and summaries
  12. import statsmodels.formula.api as smf  # Wrapper for nice formulas
  13. from palettable.wesanderson import GrandBudapest1_4  # Pretty color schemes
  14. cmap = GrandBudapest1_4.mpl_colors
  15.  
  16. from numpy.random import default_rng
  17. rng = default_rng()
  18.  
  19. import matplotlib
  20. matplotlib.rcParams['mathtext.fontset'] = 'stix'
  21. matplotlib.rcParams['font.family'] = 'STIXGeneral'
  22.  
  23.  
  24. def gen_dep_data(N=300, outcome_pos_ratio=0.5):
  25.     """Generate random data related to the outcome variable
  26.    with different strengths of relationships. To do that, I generate a
  27.    uniformly distributed array [0, 1) of size N. I then sample from the two same
  28.    distributions with different parameters depending on the outcome and the
  29.    strength of the relationship.
  30.    """
  31.     rand_arr = rng.uniform(size=N)
  32.     y = rng.choice([0, 1],
  33.                    size=N,
  34.                    p=[1 - outcome_pos_ratio, outcome_pos_ratio])
  35.  
  36.     λ1, λ2 = 30, 20
  37.     exp1 = rng.exponential(λ1, size=N)
  38.     exp2 = rng.exponential(λ2, size=N)
  39.     weak_rel = 0.2
  40.  
  41.     weak_rel_exp = np.where((rand_arr < weak_rel) & (y == 1), exp1, exp2)
  42.     k1, θ1 = 2, 5
  43.     k2, θ2 = 1, 6
  44.     med_rel = 0.5
  45.     gamma1 = rng.gamma(k1, θ1, size=N)
  46.     gamma2 = rng.gamma(k2, θ2, size=N)
  47.     med_rel_gamma = np.where((rand_arr < med_rel) & (y == 1), gamma1, gamma2)
  48.  
  49.     μ1, σ1 = 10, 10
  50.     μ2, σ2 = 5, 20
  51.     gauss1 = rng.normal(μ1, σ1, size=N)
  52.     gauss2 = rng.normal(μ2, σ2, size=N)
  53.     med_rel_gauss = np.where((rand_arr < med_rel) & (y == 1), gauss1, gauss2)
  54.     d1, d2 = 3, 6
  55.     d3, d4 = 5, 8
  56.     strong_rel = 0.8
  57.     f_dist1 = rng.f(d1, d2, size=N)
  58.     f_dist2 = rng.f(d3, d3, size=N)
  59.     strong_rel_fdist = np.where((rand_arr < strong_rel) & (y == 1), f_dist1,
  60.                                 f_dist2)
  61.  
  62.     unrelated = rng.choice([0, 1], size=N)
  63.  
  64.     data = {
  65.         "y": y,
  66.         "weak_related_exp": weak_rel_exp,
  67.         "med_related_gamma": med_rel_gamma,
  68.         "med_related_gauss": med_rel_gauss,
  69.         "strong_related_fdist": strong_rel_fdist,
  70.         "unrelated": unrelated
  71.     }
  72.     df = pd.DataFrame(data)
  73.  
  74.     return df
  75.  
  76.  
  77. N = 600
  78. df = gen_dep_data(N)
  79.  
  80.  
  81. def gen_formula(df, dep_var_name="y"):
  82.     "Create formula for a given dataframe because I'm lazy"
  83.     rhs = "".join(f"{c} + " for c in df.columns if c != "y").rstrip('+ ')
  84.     formula = dep_var_name + " ~ " + rhs
  85.     return formula
  86.  
  87.  
  88. formula = gen_formula(df)
  89.  
  90. lpm = smf.ols(formula, data=df).fit()
  91.  
  92. logreg = smf.logit(formula, data=df).fit(disp=False)  # no convergence message
  93.  
  94. lpm_params = lpm.params
  95. logreg_params = logreg.params
  96.  
  97. col_names = {0: "Linear Probability Model", 1: "Logit Model"}
  98. results = pd.DataFrame((lpm_params, logreg_params)).T.rename(columns=col_names)
  99.  
  100. m1 = lpm.predict()
  101. m2 = sm.Logit(df["y"].values, m1).fit(disp=False).predict()
  102. m3 = logreg.predict()
  103.  
  104. # Make m1 behave
  105.  
  106. m1[m1 < 0], m1[m1 > 1] = 0, 1
  107.  
  108. random = rng.uniform(size=N)
  109.  
  110.  
  111. def pretty_plots(save=False):
  112.     fig, ax = plt.subplots(nrows=1,
  113.                            ncols=3,
  114.                            sharey=True,
  115.                            figsize=(12, 4),
  116.                            dpi=300)
  117.     for idx, m in enumerate([m1, m2, m3]):
  118.         ax[idx].hist(m, bins=40, color=cmap[idx])
  119.         ax[idx].yaxis.set_tick_params(labelbottom=True)
  120.     for i, j in enumerate(
  121.         ["LPM predictions", "y ~ LPM preds", "Logit predictions"]):
  122.         ax[i].set_title(j, pad=10)
  123.     fig.suptitle("Distributions of predicted probabilities",
  124.                  fontsize=16,
  125.                  y=1.05)
  126.     if save:
  127.         plt.savefig("img/lpm-vs-logit.png", dpi=300, bbox_inches="tight")
  128.     plt.show()
  129.  
  130.  
  131. pretty_plots()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement