Advertisement
Guest User

Untitled

a guest
Jun 20th, 2018
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.60 KB | None | 0 0
  1. def train_readout(X, y, Xt, yt, PRINTSTUFF, reg_fact=0):
  2. """
  3. Train and test a readout. Assumes labels are integers (0., ..., N_classes-1).
  4.  
  5. :param X: training data
  6. :param y: training labels
  7. :param Xt: test data
  8. :param yt: test labels
  9. :param PRINTSTUFF: print or not
  10. :param reg_fact: regularization constant, default: 0
  11. :returns: tuple of training and test error
  12. """
  13.  
  14. # add constant component to states for bias
  15. X = np.hstack((np.ones((len(X), 1)), X))
  16. Xt = np.hstack((np.ones((len(Xt), 1)), Xt))
  17.  
  18. # train
  19. if reg_fact == 0:
  20. w = np.linalg.lstsq(X, y)[0]
  21. else:
  22. I = np.eye(X.shape[1])
  23. w = np.linalg.inv(np.dot(X.T, X) + reg_fact*I).dot(X.T).dot(y)
  24.  
  25. # compute predictions
  26. max_label = y.max()
  27. y_train = np.clip(np.dot(X, w).round(), 0., max_label).astype(int)
  28. y_test = np.clip(np.dot(Xt, w).round(), 0., max_label).astype(int)
  29.  
  30. # compute errors
  31. train_err = (y_train != y).sum() / len(y)
  32. test_err = (y_test != yt).sum() / len(yt)
  33.  
  34. if PRINTSTUFF:
  35. print("X")
  36. print(X)
  37. print(X.shape)
  38.  
  39. print("Xt")
  40. print(Xt)
  41. print(Xt.shape)
  42.  
  43. print("reg_fact")
  44. print(reg_fact)
  45.  
  46. print("w")
  47. print(w)
  48. print(w.shape)
  49.  
  50. print("y_train")
  51. print(y_train)
  52. print(y_train.shape)
  53.  
  54. print("y_test")
  55. print(y_test)
  56. print(y_test.shape)
  57.  
  58. print("train_err")
  59. print(train_err)
  60.  
  61. print("test_err")
  62. print(test_err)
  63.  
  64. return train_err, test_err
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement