Advertisement
Guest User

Untitled

a guest
Jan 5th, 2024
132
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.61 KB | None | 0 0
  1. # %%
  2. import numpy as np
  3. from sklearn.model_selection import train_test_split
  4. from sklearn.linear_model import LogisticRegression
  5.  
  6. D = 20  # instrinsic
  7. repeats = 20
  8. d = D * repeats  # data dimension
  9.  
  10. layers = 10
  11. norm = 1
  12. weight_matrix_in = [norm * np.random.randn(d, d) / np.sqrt(d) for _ in range(layers)]
  13. weight_matrix_out = [norm * np.random.randn(d, d) / np.sqrt(d) for _ in range(layers)]
  14. activation = lambda x: np.maximum(x, 0)
  15.  
  16.  
  17. def forward(x):
  18.     for i in range(layers):
  19.         x = x + weight_matrix_out[i] @ activation(weight_matrix_in[i] @ x)
  20.     return x
  21.  
  22.  
  23. N = 10000
  24. # get random points in {-1, 1}^D
  25. X = np.random.choice([-1, 1], size=(N, D))
  26. # %%
  27. x = np.concatenate([X] * repeats, axis=1)
  28. Z = forward(x.T).T
  29. # %%
  30. from matplotlib import pyplot as plt
  31.  
  32. # see if you can still classify
  33. n = 30
  34. accs = []
  35. for _ in range(n):
  36.     c = np.random.randint(D)
  37.     labels = np.sign(X[:, c])
  38.     yt, yv, zt, zv = train_test_split(labels, Z, test_size=0.2)
  39.     model = LogisticRegression(max_iter=1000)
  40.     model.fit(zt, yt)
  41.     accs.append(model.score(zv, yv))
  42. plt.hist(accs, bins=20, range=(0, 1))
  43. plt.axvline(0.5, color="black")
  44. plt.title("base features")
  45. # %%
  46. # see if you can classify xors
  47. accs = []
  48. for _ in range(n):
  49.     c1, c2 = np.random.choice(D, size=2, replace=False)
  50.     labels = np.sign(X[:, c1] * X[:, c2])
  51.     yt, yv, zt, zv = train_test_split(labels, Z, test_size=0.2)
  52.     model = LogisticRegression(max_iter=1000)
  53.     model.fit(zt, yt)
  54.     accs.append(model.score(zv, yv))
  55. plt.hist(accs, bins=20, range=(0, 1))
  56. plt.axvline(0.5, color="black")
  57. plt.title("xor features")
  58. # %%
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement