Cassimus

RosklaLogistyczny

Nov 24th, 2025
961
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.76 KB | None | 0 0
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from sklearn.linear_model import LogisticRegression
  4.  
  5. def generuj_zbior(n, seed, stromosc_krzywej, sukces_50_procent):
  6.     gen_liczb = np.random.default_rng(seed)
  7.     czas = gen_liczb.uniform(0,60,n)
  8.     prawdopodobienstwo_sukcesu = 1/(1+np.exp(-stromosc_krzywej*(czas-sukces_50_procent)))
  9.     wynik = gen_liczb.binomial(1, prawdopodobienstwo_sukcesu)
  10.     return czas, wynik
  11.  
  12. zbiory = [
  13.     {"n":100 , "seed":1, "stromosc": 0.2, "sukces50":30, "etykieta":"Zbiór 1" },
  14.     {"n":100 , "seed":1, "stromosc": 0.15, "sukces50":25, "etykieta":"Zbiór 2" },
  15.     {"n":100 , "seed":1, "stromosc": 0.25, "sukces50":35, "etykieta":"Zbiór 3" },
  16.     {"n":100 , "seed":1, "stromosc": 0.1, "sukces50":28, "etykieta":"Zbiór 4" }
  17. ]
  18.  
  19. siatka = np.linspace(0,60,500).reshape(-1,1)
  20.  
  21. fig, axs = plt.subplots(2,2, figsize=(12,10))
  22. axs = axs.ravel()
  23.  
  24. for i, zbior in enumerate(zbiory):
  25.     czas, wynik = generuj_zbior(zbior["n"], zbior["seed"], zbior["stromosc"], zbior["sukces50"])
  26.     osX = czas.reshape(-1,1)
  27.     model = LogisticRegression()
  28.     model.fit(osX,wynik)
  29.  
  30.     predykacja = model.predict_proba(siatka)[:,1] * 100
  31.     # dodanie do zbioru krzywych
  32.  
  33.  
  34.     axs[i].scatter(czas[wynik==0], wynik[wynik==0]*100 , c="red", alpha=0.6, label = "Porażka")
  35.     axs[i].scatter(czas[wynik==1], wynik[wynik==1]*100 , c="green", alpha=0.6, label = "Sukces")
  36.  
  37.     axs[i].plot(siatka, predykacja, "b-",linewidth = 2, label="Regresja logistyczna")
  38.  
  39.     axs[i].set_title(zbior["etykieta"])
  40.     axs[i].set_xlabel("Czas trwania gry [min]")
  41.     axs[i].set_ylabel("Prawdopodobieństwo sukcesu [%]")
  42.     axs[i].grid(True, alpha= 0.3)
  43.     axs[i].legend()
  44.  
  45. plt.suptitle("Regresja logistyczna", fontsize=14)
  46. plt.tight_layout()
  47. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment