Advertisement
Guest User

Untitled

a guest
Jun 20th, 2018
67
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.50 KB | None | 0 0
  1. def train_readout(X, y, Xt, yt, PRINTSTUFF, reg_fact=0):
  2. """
  3. Train and test a readout. Assumes labels are integers (0., ..., N_classes-1).
  4.  
  5. :param X: training data
  6. :param y: training labels
  7. :param Xt: test data
  8. :param yt: test labels
  9. :param PRINTSTUFF: print or not
  10. :param reg_fact: regularization constant, default: 0
  11. :returns: tuple of training and test error
  12. """
  13.  
  14. # add constant component to states for bias
  15. X = np.hstack((np.ones((len(X), 1)), X))
  16. Xt = np.hstack((np.ones((len(Xt), 1)), Xt))
  17.  
  18. # train
  19. if reg_fact == 0:
  20. w = np.linalg.lstsq(X, y)[0]
  21. else:
  22. I = np.eye(X.shape[1])
  23. w = np.linalg.inv(np.dot(X.T, X) + reg_fact*I).dot(X.T).dot(y)
  24.  
  25. # compute predictions
  26. max_label = y.max()
  27. y_train = np.clip(np.dot(X, w).round(), 0., max_label).astype(int)
  28. y_test = np.clip(np.dot(Xt, w).round(), 0., max_label).astype(int)
  29.  
  30. # compute errors
  31. train_err = (y_train != y).sum() / len(y)
  32. test_err = (y_test != yt).sum() / len(yt)
  33.  
  34. if PRINTSTUFF:
  35. print("X")
  36. print(X.shape)
  37.  
  38. print("Xt")
  39. print(Xt.shape)
  40.  
  41. print("reg_fact")
  42. print(reg_fact)
  43.  
  44. print("w")
  45. print(w.shape)
  46.  
  47. print("y_train")
  48. print(y_train.shape)
  49.  
  50. print("y_test")
  51. print(y_test.shape)
  52.  
  53. print("train_err")
  54. print(train_err)
  55.  
  56. print("test_err")
  57. print(test_err)
  58.  
  59. return train_err, test_err
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement